diff --git a/agentflow/core/graph/agent_internal/google.py b/agentflow/core/graph/agent_internal/google.py index 08ee6cf..6304c35 100644 --- a/agentflow/core/graph/agent_internal/google.py +++ b/agentflow/core/graph/agent_internal/google.py @@ -304,9 +304,6 @@ def _build_google_config( structured_output = getattr(self, "output_schema", None) is not None text_like_output = self.output_type in ("text", "json") - if system_instruction: - config_kwargs["system_instruction"] = system_instruction - if "temperature" in call_kwargs: config_kwargs["temperature"] = call_kwargs.pop("temperature") if "max_tokens" in call_kwargs or "max_output_tokens" in call_kwargs: @@ -315,6 +312,13 @@ def _build_google_config( call_kwargs.pop("max_output_tokens", None), ) + cached_content = call_kwargs.pop("cached_content", None) + if cached_content: + # system_instruction is already inside the cache — don't resend it + config_kwargs["cached_content"] = cached_content + elif system_instruction: + config_kwargs["system_instruction"] = system_instruction + if tools and text_like_output and not structured_output: function_declarations = self._convert_tools_to_google_format(tools) if function_declarations: @@ -370,11 +374,17 @@ async def _call_google_content_generation( mode_suffix, self.model, ) - return await self.client.aio.models.generate_content( + response = await self.client.aio.models.generate_content( model=self.model, contents=google_contents, config=config, ) + cached = ( + getattr(getattr(response, "usage_metadata", None), "cached_content_token_count", 0) or 0 + ) + if cached: + logger.debug("Cache hit: %d cached tokens (Google)", cached) + return response async def _call_google( self, @@ -397,9 +407,24 @@ async def _call_google( call_kwargs = {**self.llm_kwargs, **kwargs} + # Peek before _build_google_config pops it, so we know whether a cache is active. + has_explicit_cache = bool(call_kwargs.get("cached_content")) + system_instruction, google_contents = self._convert_to_google_format(messages) config = self._build_google_config(system_instruction, tools, call_kwargs) + # When an explicit cache is active, system_instruction is excluded from the + # config (the static part lives inside the cache). Any dynamic additions + # — memory injections, skill prompts, per-request state — are preserved by + # prepending them as a leading user message so the model still sees them. + if has_explicit_cache and system_instruction: + from google.genai import types + + google_contents = [ + types.Content(role="user", parts=[types.Part(text=system_instruction)]), + *google_contents, + ] + if structured_output: if config is None: from google.genai import types diff --git a/agentflow/core/graph/agent_internal/openai.py b/agentflow/core/graph/agent_internal/openai.py index 4686fec..605f353 100644 --- a/agentflow/core/graph/agent_internal/openai.py +++ b/agentflow/core/graph/agent_internal/openai.py @@ -101,25 +101,36 @@ async def _call_openai( call_kwargs["tools"] = tools logger.debug("Calling OpenAI beta.chat.completions.parse with model=%s", self.model) - return await self.client.beta.chat.completions.parse( + response = await self.client.beta.chat.completions.parse( model=self.model, messages=messages, response_format=output_schema, stream=False, **call_kwargs, ) + details = getattr(getattr(response, "usage", None), "prompt_tokens_details", None) + cached = getattr(details, "cached_tokens", 0) or 0 + if cached: + logger.debug("Cache hit: %d cached tokens (OpenAI chat completions)", cached) + return response if self.output_type in ("text", "json"): if tools: call_kwargs["tools"] = tools logger.debug("Calling OpenAI chat.completions.create with model=%s", self.model) - return await self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model, messages=messages, stream=stream, **call_kwargs, ) + if not stream: + details = getattr(getattr(response, "usage", None), "prompt_tokens_details", None) + cached = getattr(details, "cached_tokens", 0) or 0 + if cached: + logger.debug("Cache hit: %d cached tokens (OpenAI chat completions)", cached) + return response if self.output_type == "image": prompt = self._extract_prompt(messages) @@ -223,15 +234,21 @@ async def _call_openai_responses( # noqa: PLR0912 call_kwargs["instructions"] = instructions if responses_tools: call_kwargs["tools"] = responses_tools - if self.reasoning_config: + if self.reasoning_config: # type: ignore call_kwargs["reasoning"] = self.reasoning_config call_kwargs.pop("reasoning_effort", None) logger.debug("Calling OpenAI responses.create with model=%s", self.model) - return await self.client.responses.create( + response = await self.client.responses.create( model=self.model, input=input_items, stream=stream, **call_kwargs, ) + if not stream: + details = getattr(getattr(response, "usage", None), "input_tokens_details", None) + cached = getattr(details, "cached_tokens", 0) or 0 + if cached: + logger.debug("Cache hit: %d cached tokens (OpenAI responses API)", cached) + return response diff --git a/agentflow/core/llm/__init__.py b/agentflow/core/llm/__init__.py index c13c90b..e2b2178 100644 --- a/agentflow/core/llm/__init__.py +++ b/agentflow/core/llm/__init__.py @@ -1,6 +1,7 @@ """LLM client creation utilities shared across agents and evaluators.""" +from .caller import call_llm from .client_factory import create_llm_client, detect_provider -__all__ = ["create_llm_client", "detect_provider"] +__all__ = ["call_llm", "create_llm_client", "detect_provider"] diff --git a/agentflow/core/llm/caller.py b/agentflow/core/llm/caller.py new file mode 100644 index 0000000..1acb2b5 --- /dev/null +++ b/agentflow/core/llm/caller.py @@ -0,0 +1,275 @@ +"""Shared single-turn LLM call utility. + +Centralises the Google / OpenAI dispatch that was duplicated across +SummaryContextManager, LLMCallerMixin (eval judge), and UserSimulator. + +Returns a plain 4-tuple so callers can choose how much they consume: + + text, *_ = await call_llm(...) # only text + text, inp, out, cache = await call_llm(...) # text + token counts + +OpenAI supports two API styles: +- ``"responses"`` (default) — ``client.responses.create()``, the current + recommended API with ``input`` / ``instructions`` / ``max_output_tokens``. +- ``"chat"`` — ``client.chat.completions.create()``, the legacy style + required by older or third-party-hosted models (e.g. some Chinese models). + +The Agent class keeps its own execution path (streaming, tools, retry, etc.) +and is unaffected by this module. +""" + +from __future__ import annotations + +import logging +from typing import Any, Literal + +from agentflow.core.llm.client_factory import create_llm_client, detect_provider + + +logger = logging.getLogger("agentflow.llm.caller") + + +async def call_llm( + model: str, + prompt: str, + *, + system_prompt: str | None = None, + max_tokens: int = 1024, + temperature: float = 0.3, + json_mode: bool = False, + use_vertex_ai: bool = False, + api_style: Literal["responses", "chat"] = "responses", + **llm_kwargs: Any, +) -> tuple[str, int, int, int]: + """Single-turn LLM call with provider auto-detection. + + Args: + model: Model identifier (e.g. ``"gemini-2.0-flash"``, ``"gpt-4o-mini"``). + Provider is inferred from the name via ``detect_provider``. + prompt: The user-turn content to send. + system_prompt: Optional system instruction prepended to the request. + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + json_mode: When ``True``, instructs the provider to return valid JSON. + use_vertex_ai: When ``True``, force Google Vertex AI client. + api_style: OpenAI only. ``"responses"`` (default) uses the current + Responses API (``client.responses.create``). Use ``"chat"`` for + models that only support the legacy Chat Completions endpoint + (e.g. older or self-hosted Chinese models). + **llm_kwargs: Provider-specific parameters forwarded directly to the + underlying API call. Examples: + + - Google: ``cached_content="cachedContents/abc123"`` — attaches an + explicit Gemini context cache created via the Google SDK. + - OpenAI: ``prompt_cache_key="my-agent-v1"`` — improves cache hit + rates across requests sharing the same long system-prompt prefix. + - OpenAI: ``prompt_cache_retention="24h"`` — extends cache retention + for gpt-5.5+ models (default is in-memory, ~5-10 min). + + Returns: + ``(text, input_tokens, output_tokens, cache_read_tokens)`` — plain tuple. + Token counts are 0 when the provider does not report them. + """ + provider = detect_provider(model, use_vertex_ai=use_vertex_ai) + client = create_llm_client(provider, use_vertex_ai=use_vertex_ai) + + if provider == "google": + return await _call_google( + client, + model, + prompt, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + json_mode=json_mode, + **llm_kwargs, + ) + + if api_style == "chat": + return await _call_openai_chat( + client, + model, + prompt, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + json_mode=json_mode, + **llm_kwargs, + ) + return await _call_openai_responses( + client, + model, + prompt, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + json_mode=json_mode, + **llm_kwargs, + ) + + +# --------------------------------------------------------------------------- +# Google +# --------------------------------------------------------------------------- + + +async def _call_google( + client: Any, + model: str, + prompt: str, + *, + system_prompt: str | None, + max_tokens: int, + temperature: float, + json_mode: bool, + **llm_kwargs: Any, +) -> tuple[str, int, int, int]: + from google.genai import types + + config_kwargs: dict[str, Any] = { + "max_output_tokens": max_tokens, + "temperature": temperature, + } + if json_mode: + config_kwargs["response_mime_type"] = "application/json" + + cached_content = llm_kwargs.pop("cached_content", None) + if cached_content: + # system_prompt is already inside the cache — don't resend it + config_kwargs["cached_content"] = cached_content + elif system_prompt: + config_kwargs["system_instruction"] = system_prompt + + response = await client.aio.models.generate_content( + model=model, + contents=prompt, + config=types.GenerateContentConfig(**config_kwargs), + ) + + text = (response.text or "").strip() + inp = out = cache = 0 + meta = getattr(response, "usage_metadata", None) + if meta is not None: + inp = getattr(meta, "prompt_token_count", 0) or 0 + out = getattr(meta, "candidates_token_count", 0) or 0 + cache = getattr(meta, "cached_content_token_count", 0) or 0 + + if cache: + logger.debug("Cache hit: %d cached tokens (Google)", cache) + + return text, inp, out, cache + + +# --------------------------------------------------------------------------- +# OpenAI — Responses API (default) +# --------------------------------------------------------------------------- + + +async def _call_openai_responses( + client: Any, + model: str, + prompt: str, + *, + system_prompt: str | None, + max_tokens: int, + temperature: float, + json_mode: bool, + **llm_kwargs: Any, +) -> tuple[str, int, int, int]: + """Call the OpenAI Responses API (client.responses.create).""" + kwargs: dict[str, Any] = { + "model": model, + "input": prompt, + "max_output_tokens": max_tokens, + "temperature": temperature, + } + if system_prompt: + kwargs["instructions"] = system_prompt + if json_mode: + kwargs["text"] = {"format": {"type": "json_object"}} + kwargs.update(llm_kwargs) + + response = await client.responses.create(**kwargs) + + text = _extract_responses_text(response) + inp = out = cache = 0 + usage = getattr(response, "usage", None) + if usage is not None: + inp = getattr(usage, "input_tokens", 0) or 0 + out = getattr(usage, "output_tokens", 0) or 0 + details = getattr(usage, "input_tokens_details", None) + if details is not None: + cache = getattr(details, "cached_tokens", 0) or 0 + + if cache: + logger.debug("Cache hit: %d cached tokens (OpenAI responses API)", cache) + + return text, inp, out, cache + + +def _extract_responses_text(response: Any) -> str: + """Extract the assistant text from an OpenAI Responses API response object.""" + # SDK convenience property available in openai >= 1.61 + output_text = getattr(response, "output_text", None) + if output_text is not None: + return str(output_text).strip() + + # Manual fallback: iterate output items + for item in getattr(response, "output", []): + item_type = getattr(item, "type", None) + if item_type == "message": + for part in getattr(item, "content", []): + if getattr(part, "type", None) == "output_text": + return (getattr(part, "text", "") or "").strip() + + return "" + + +# --------------------------------------------------------------------------- +# OpenAI — Chat Completions (legacy / compat) +# --------------------------------------------------------------------------- + + +async def _call_openai_chat( + client: Any, + model: str, + prompt: str, + *, + system_prompt: str | None, + max_tokens: int, + temperature: float, + json_mode: bool, + **llm_kwargs: Any, +) -> tuple[str, int, int, int]: + """Call the OpenAI Chat Completions API (client.chat.completions.create).""" + messages: list[dict[str, str]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + kwargs: dict[str, Any] = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + } + if json_mode: + kwargs["response_format"] = {"type": "json_object"} + kwargs.update(llm_kwargs) + + response = await client.chat.completions.create(**kwargs) + + text = (response.choices[0].message.content or "").strip() + inp = out = cache = 0 + usage = getattr(response, "usage", None) + if usage is not None: + inp = getattr(usage, "prompt_tokens", 0) or 0 + out = getattr(usage, "completion_tokens", 0) or 0 + details = getattr(usage, "prompt_tokens_details", None) + if details is not None: + cache = getattr(details, "cached_tokens", 0) or 0 + + if cache: + logger.debug("Cache hit: %d cached tokens (OpenAI chat completions)", cache) + + return text, inp, out, cache diff --git a/agentflow/core/state/__init__.py b/agentflow/core/state/__init__.py index 69859c6..2a00dc4 100644 --- a/agentflow/core/state/__init__.py +++ b/agentflow/core/state/__init__.py @@ -39,6 +39,7 @@ ) from .stream_chunks import StreamChunk, StreamEvent from .stream_emitter import StreamEmitter +from .summary_context_manager import SummaryContextManager from .tool_result import ToolResult @@ -62,6 +63,7 @@ "StreamChunk", "StreamEmitter", "StreamEvent", + "SummaryContextManager", "TextBlock", "TextBlock", "TokenUsages", diff --git a/agentflow/core/state/summary_context_manager.py b/agentflow/core/state/summary_context_manager.py new file mode 100644 index 0000000..71a1507 --- /dev/null +++ b/agentflow/core/state/summary_context_manager.py @@ -0,0 +1,266 @@ +"""SummaryContextManager — LLM-backed context summarization with token-budget support. + +When the conversation context grows beyond ``max_messages`` or exceeds the +``token_budget`` estimate, the oldest messages are summarised by an LLM call +and replaced with a concise text stored in ``state.context_summary``. The +most recent ``keep_recent`` messages are kept verbatim so the agent always has +immediate context. + +``convert_messages`` already injects ``state.context_summary`` as an assistant +message before the retained context, so no changes to the execution path are +required. + +Provider auto-detection follows the same rules as the rest of the framework: +``gemini-*`` / ``imagen-*`` → Google GenAI; ``gpt-*`` / ``o1-*`` etc. → OpenAI. +Any model string supported by ``detect_provider`` works here. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Literal, TypeVar + +from agentflow.core.llm.caller import call_llm +from agentflow.core.llm.client_factory import detect_provider +from agentflow.core.state.agent_state import AgentState +from agentflow.core.state.base_context import BaseContextManager +from agentflow.core.state.message import Message +from agentflow.core.state.reducers import remove_tool_messages + + +S = TypeVar("S", bound=AgentState) + +logger = logging.getLogger("agentflow.state.summary") + +_DEFAULT_SUMMARY_PROMPT = ( + "You are a conversation summarizer. " + "Summarize the following conversation history concisely, preserving all important facts, " + "decisions, tool results, and context needed to continue the conversation. " + "Write in third-person past tense. Be factual and specific. " + "Do not add commentary or explanations about the summary itself." +) + + +def _estimate_tokens(messages: list[Message]) -> int: + """Rough token estimate: 1 token ≈ 4 characters.""" + total_chars = 0 + for msg in messages: + for block in msg.content: + text = getattr(block, "text", None) + if text: + total_chars += len(text) + if msg.tools_calls: + total_chars += len(str(msg.tools_calls)) + return max(1, total_chars // 4) + + +def _messages_to_text(messages: list[Message]) -> str: + """Render messages as a readable block for the summarizer.""" + parts: list[str] = [] + for msg in messages: + role = msg.role.upper() + text_parts: list[str] = [] + for block in msg.content: + t = getattr(block, "text", None) + if t: + text_parts.append(t.strip()) + if msg.tools_calls: + text_parts.append(f"[Tool calls: {msg.tools_calls}]") + if text_parts: + parts.append(f"{role}: {' '.join(text_parts)}") + return "\n".join(parts) + + +class SummaryContextManager(BaseContextManager[S]): + """Context manager that compresses old messages into an LLM-generated summary. + + Summarisation is triggered when *either* threshold is exceeded: + + * ``max_messages``: total message count in ``state.context`` + * ``token_budget``: estimated token count of the context + + After summarisation the oldest messages are removed and the generated text + is stored in ``state.context_summary``. Subsequent summarisations append + to the existing summary so no historical information is permanently lost. + + Args: + model: Model identifier for summarisation (e.g. ``"gemini-2.0-flash"``, + ``"gpt-4o-mini"``). Provider is auto-detected from the name. + max_messages: Trigger summarisation when message count exceeds this. + ``None`` disables the count-based trigger. + token_budget: Trigger summarisation when estimated token count exceeds + this. ``None`` disables the token-budget trigger. + keep_recent: Number of most-recent non-system messages to retain + verbatim after summarisation. + remove_tool_msgs: When ``True``, strip tool-call and tool-result messages + from the context before checking thresholds and before summarising. + Mirrors the behaviour of ``MessageContextManager``. + summary_system_prompt: Override the default summarisation instruction + sent to the LLM. + max_summary_tokens: Upper bound on the summary output length (tokens). + api_style: OpenAI only. ``"responses"`` (default) uses the Responses API. + Use ``"chat"`` for models that only support the legacy Chat Completions + endpoint (e.g. older or third-party-hosted Chinese models). + + Example:: + + from agentflow.core.state import SummaryContextManager + + manager = SummaryContextManager( + model="gemini-2.0-flash", + token_budget=6000, + keep_recent=6, + ) + app = agent.compile(context_manager=manager) + + # Or trigger on message count instead: + manager = SummaryContextManager( + model="gpt-4o-mini", + max_messages=30, + token_budget=8000, # either threshold fires summarisation + keep_recent=8, + ) + """ + + def __init__( + self, + model: str, + *, + max_messages: int | None = 30, + token_budget: int | None = None, + keep_recent: int = 8, + remove_tool_msgs: bool = False, + summary_system_prompt: str | None = None, + max_summary_tokens: int = 600, + api_style: Literal["responses", "chat"] = "responses", + ) -> None: + self.model = model + self.max_messages = max_messages + self.token_budget = token_budget + self.keep_recent = keep_recent + self.remove_tool_msgs = remove_tool_msgs + self.summary_system_prompt = summary_system_prompt or _DEFAULT_SUMMARY_PROMPT + self.max_summary_tokens = max_summary_tokens + self.api_style = api_style + + self._provider: str = detect_provider(model) + + # ------------------------------------------------------------------ + # Trigger logic + # ------------------------------------------------------------------ + + def _should_summarize(self, messages: list[Message]) -> bool: + """Return True if any configured threshold is exceeded.""" + if self.max_messages is not None and len(messages) > self.max_messages: + logger.debug( + "Summarisation triggered by message count: %d > %d", + len(messages), + self.max_messages, + ) + return True + if self.token_budget is not None: + estimated = _estimate_tokens(messages) + if estimated > self.token_budget: + logger.debug( + "Summarisation triggered by token budget: ~%d tokens > %d budget", + estimated, + self.token_budget, + ) + return True + return False + + # ------------------------------------------------------------------ + # Context split + # ------------------------------------------------------------------ + + def _split_context(self, messages: list[Message]) -> tuple[list[Message], list[Message]]: + """Split into (messages_to_summarise, messages_to_keep). + + System messages are always kept and never summarised. + The most recent ``keep_recent`` non-system messages are kept verbatim. + """ + system_msgs = [m for m in messages if m.role == "system"] + non_system = [m for m in messages if m.role != "system"] + + if len(non_system) <= self.keep_recent: + return [], messages + + split_at = len(non_system) - self.keep_recent + to_summarize = non_system[:split_at] + to_keep = non_system[split_at:] + return to_summarize, system_msgs + to_keep + + # ------------------------------------------------------------------ + # LLM calls + # ------------------------------------------------------------------ + + async def _summarize(self, messages: list[Message]) -> str: + text = _messages_to_text(messages) + if not text.strip(): + return "" + summary, *_ = await call_llm( + self.model, + f"Conversation to summarize:\n\n{text}", + system_prompt=self.summary_system_prompt, + max_tokens=self.max_summary_tokens, + api_style=self.api_style, + ) + return summary + + # ------------------------------------------------------------------ + # BaseContextManager interface + # ------------------------------------------------------------------ + + async def atrim_context(self, state: S) -> S: + messages = state.context + if not messages: + return state + + if self.remove_tool_msgs: + messages = remove_tool_messages(messages) + logger.debug("Removed tool messages; %d messages remaining", len(messages)) + + if not self._should_summarize(messages): + # If tool messages were stripped, still commit the cleaned list + if self.remove_tool_msgs: + state.context = messages + return state + + to_summarize, remaining = self._split_context(messages) + if not to_summarize: + return state + + logger.debug( + "Summarising %d messages; keeping %d recent (provider=%s, model=%s)", + len(to_summarize), + len(remaining), + self._provider, + self.model, + ) + + try: + new_summary = await self._summarize(to_summarize) + except Exception: + logger.exception("Summarisation LLM call failed; leaving context unchanged") + return state + + if not new_summary: + return state + + # Rolling summary: append to any previously stored summary + if state.context_summary: + state.context_summary = state.context_summary + "\n\n" + new_summary + else: + state.context_summary = new_summary + + state.context = remaining + logger.debug( + "Context reduced to %d messages; cumulative summary length=%d chars", + len(remaining), + len(state.context_summary), + ) + return state + + def trim_context(self, state: S) -> S: + return asyncio.run(self.atrim_context(state)) diff --git a/agentflow/prebuilt/__init__.py b/agentflow/prebuilt/__init__.py index 0619707..d09feb3 100644 --- a/agentflow/prebuilt/__init__.py +++ b/agentflow/prebuilt/__init__.py @@ -6,6 +6,10 @@ from __future__ import annotations +# Context managers +from agentflow.core.state.message_context_manager import MessageContextManager +from agentflow.core.state.summary_context_manager import SummaryContextManager + # Agents from .agent import ( BaseReranker, @@ -43,10 +47,13 @@ "BaseReranker", "CohereReranker", "CrossEncoderReranker", + # Context managers + "MessageContextManager", "PlanActReflectAgent", "RAGAgent", "ReactAgent", "StructuredOutputAgent", + "SummaryContextManager", "SupervisorTeamAgent", "SwarmAgent", "SwarmMemberConfig", diff --git a/agentflow/qa/evaluation/config/criterion_config.py b/agentflow/qa/evaluation/config/criterion_config.py index e272b06..4c32671 100644 --- a/agentflow/qa/evaluation/config/criterion_config.py +++ b/agentflow/qa/evaluation/config/criterion_config.py @@ -47,6 +47,9 @@ class CriterionConfig(BaseModel): keywords: Required keywords for ContainsKeywordsCriterion. check_args: Whether to check tool arguments in trajectory matching. enabled: Whether this criterion is enabled. + api_style: OpenAI API style — ``"responses"`` (default, Responses API) + or ``"chat"`` (legacy Chat Completions, required for some older + or third-party-hosted models such as Chinese LLMs). """ threshold: float = 0.8 @@ -57,6 +60,7 @@ class CriterionConfig(BaseModel): keywords: list[str] = Field(default_factory=list) check_args: bool = False enabled: bool = True + api_style: str = "responses" @classmethod def tool_name_match(cls, threshold: float = 1.0) -> CriterionConfig: diff --git a/agentflow/qa/evaluation/config/reporter_config.py b/agentflow/qa/evaluation/config/reporter_config.py index bea1a95..d77ecde 100644 --- a/agentflow/qa/evaluation/config/reporter_config.py +++ b/agentflow/qa/evaluation/config/reporter_config.py @@ -20,6 +20,8 @@ class UserSimulatorConfig(BaseModel): temperature: Temperature for generation. thinking_enabled: Whether to enable thinking/reasoning. thinking_budget: Token budget for thinking (if enabled). + api_style: OpenAI API style — ``"responses"`` (default) or ``"chat"`` + for models that only support the legacy Chat Completions endpoint. """ model: str = "gemini-2.5-flash" @@ -27,6 +29,7 @@ class UserSimulatorConfig(BaseModel): temperature: float = 0.7 thinking_enabled: bool = False thinking_budget: int = 10240 + api_style: str = "responses" class ReporterConfig(BaseModel): diff --git a/agentflow/qa/evaluation/criteria/llm_utils.py b/agentflow/qa/evaluation/criteria/llm_utils.py index 06da39c..cea289d 100644 --- a/agentflow/qa/evaluation/criteria/llm_utils.py +++ b/agentflow/qa/evaluation/criteria/llm_utils.py @@ -4,8 +4,8 @@ Provides LLMCallerMixin with _call_llm_score() used by LLMJudgeCriterion, RubricBasedCriterion, and others. -Client creation is delegated to agentflow.core.llm.client_factory so that -the eval judge uses the same provider/Vertex AI logic as the Agent class. +Delegates to agentflow.core.llm.caller.call_llm so that all single-turn +provider dispatch lives in one place. """ from __future__ import annotations @@ -13,7 +13,8 @@ import json import logging -from agentflow.core.llm.client_factory import create_llm_client, detect_provider +from agentflow.core.llm.caller import call_llm +from agentflow.core.llm.client_factory import detect_provider from agentflow.qa.evaluation.token_usage import TokenUsage @@ -21,10 +22,7 @@ def _parse_model_provider(model: str, use_vertex_ai: bool = False) -> tuple[str, str]: - """Return (provider, model_name_without_prefix). - - Combines provider detection and prefix stripping in one call. - """ + """Return (provider, model_name_without_prefix).""" provider = detect_provider(model, use_vertex_ai=use_vertex_ai) model_name = model.split("/", 1)[-1] if "/" in model else model return provider, model_name @@ -34,111 +32,69 @@ class LLMCallerMixin: """Mixin providing shared LLM calling logic for LLM-based criteria. Reads ``self.config.judge_model`` to determine provider and model. - Supports the same provider/Vertex AI path as the Agent class via the - shared ``create_llm_client`` factory. """ - async def _call_llm_json(self, prompt: str) -> tuple[dict, TokenUsage]: - """Call the judge LLM and return (parsed JSON dict, token usage). - - Args: - prompt: The evaluation prompt to send. - - Returns: - Tuple of (parsed JSON dict, TokenUsage for this call). - """ + async def _call_google_json(self, prompt: str) -> tuple[dict, TokenUsage] | None: + """Call Google LLM and return (parsed JSON dict, token usage), or None if unavailable.""" judge_model: str = self.config.judge_model # type: ignore[attr-defined] use_vertex_ai: bool = getattr(self.config, "use_vertex_ai", False) # type: ignore[attr-defined] - - provider, model_name = _parse_model_provider(judge_model, use_vertex_ai=use_vertex_ai) - - try: - client = create_llm_client(provider, use_vertex_ai=use_vertex_ai) - except (ImportError, ValueError) as exc: - logger.warning("Could not create %s client: %s", provider, exc) - return {"score": 0.5, "reasoning": f"LLM client unavailable: {exc}"}, TokenUsage() - - if provider == "google": - result = await self._call_google_json(client, model_name, prompt) - else: - result = await self._call_openai_json(client, model_name, prompt) - - if result is None: - logger.warning("LLM call returned no result, using default score") - return {"score": 0.5, "reasoning": "No result from LLM"}, TokenUsage() - - return result - - async def _call_google_json( - self, client: object, model: str, prompt: str - ) -> tuple[dict, TokenUsage] | None: - """Call Google GenAI and return (parsed JSON dict, TokenUsage), or None on failure.""" try: - from google.genai import types - - config = types.GenerateContentConfig( + text, inp, out, cache = await call_llm( + judge_model, + prompt, temperature=0.3, - response_mime_type="application/json", - ) - response = await client.aio.models.generate_content( # type: ignore[union-attr] - model=model, - contents=prompt, - config=config, + json_mode=True, + use_vertex_ai=use_vertex_ai, ) - text = (response.text or "").strip() if not text: - raise ValueError("Google GenAI returned empty content") - - usage = TokenUsage() - meta = getattr(response, "usage_metadata", None) - if meta is not None: - usage = TokenUsage( - input_tokens=getattr(meta, "prompt_token_count", 0) or 0, - output_tokens=getattr(meta, "candidates_token_count", 0) or 0, - cache_read_tokens=getattr(meta, "cached_content_token_count", 0) or 0, - ) - return json.loads(text), usage - except Exception as e: - logger.warning("Google GenAI call failed: %s", e) + return None + data = json.loads(text) + usage = TokenUsage(input_tokens=inp, output_tokens=out, cache_read_tokens=cache) + return data, usage + except (ImportError, ValueError, Exception): return None - async def _call_openai_json( - self, client: object, model: str, prompt: str - ) -> tuple[dict, TokenUsage] | None: - """Call OpenAI and return (parsed JSON dict, TokenUsage), or None on failure.""" + async def _call_openai_json(self, prompt: str) -> tuple[dict, TokenUsage] | None: + """Call OpenAI LLM and return (parsed JSON dict, token usage), or None if unavailable.""" + judge_model: str = self.config.judge_model # type: ignore[attr-defined] + api_style: str = getattr(self.config, "api_style", "responses") # type: ignore[attr-defined] try: - response = await client.chat.completions.create( # type: ignore[union-attr] - model=model, - messages=[{"role": "user", "content": prompt}], + text, inp, out, cache = await call_llm( + judge_model, + prompt, temperature=0.3, - response_format={"type": "json_object"}, + json_mode=True, + api_style=api_style, # type: ignore[arg-type] ) - text = (response.choices[0].message.content or "").strip() if not text: - raise ValueError("OpenAI returned empty content") - - usage = TokenUsage() - ru = getattr(response, "usage", None) - if ru is not None: - cached = 0 - details = getattr(ru, "prompt_tokens_details", None) - if details is not None: - cached = getattr(details, "cached_tokens", 0) or 0 - usage = TokenUsage( - input_tokens=getattr(ru, "prompt_tokens", 0) or 0, - output_tokens=getattr(ru, "completion_tokens", 0) or 0, - cache_read_tokens=cached, - ) - return json.loads(text), usage - except Exception as e: - logger.warning("OpenAI call failed: %s", e) + return None + data = json.loads(text) + usage = TokenUsage(input_tokens=inp, output_tokens=out, cache_read_tokens=cache) + return data, usage + except (ImportError, ValueError, Exception): return None - async def _call_llm_score(self, prompt: str) -> tuple[float, str, TokenUsage]: - """Call the judge LLM and return (score, reasoning, token_usage). + async def _call_llm_json(self, prompt: str) -> tuple[dict, TokenUsage]: + """Call the judge LLM and return (parsed JSON dict, token usage).""" + judge_model: str = self.config.judge_model # type: ignore[attr-defined] + use_vertex_ai: bool = getattr(self.config, "use_vertex_ai", False) # type: ignore[attr-defined] + provider, _ = _parse_model_provider(judge_model, use_vertex_ai=use_vertex_ai) + + # Try provider-specific methods first + result = None + if provider == "google": + result = await self._call_google_json(prompt) + elif provider == "openai": + result = await self._call_openai_json(prompt) - Convenience wrapper around :meth:`_call_llm_json` that extracts - the ``score`` and ``reasoning`` fields. - """ + if result is not None: + return result + + # If provider-specific method failed or is not available, return default + logger.warning("No LLM provider available for model: %s", judge_model) + return {"score": 0.5, "reasoning": "No LLM provider available"}, TokenUsage() + + async def _call_llm_score(self, prompt: str) -> tuple[float, str, TokenUsage]: + """Call the judge LLM and return (score, reasoning, token_usage).""" result, usage = await self._call_llm_json(prompt) return float(result.get("score", 0.0)), result.get("reasoning", ""), usage diff --git a/agentflow/qa/evaluation/simulators/user_simulator.py b/agentflow/qa/evaluation/simulators/user_simulator.py index 0642afe..8baf73f 100644 --- a/agentflow/qa/evaluation/simulators/user_simulator.py +++ b/agentflow/qa/evaluation/simulators/user_simulator.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field +from agentflow.core.llm.caller import call_llm from agentflow.qa.evaluation.token_usage import TokenUsage @@ -175,6 +176,7 @@ def __init__( max_turns: int = 10, config: UserSimulatorConfig | None = None, criteria: list[BaseCriterion] | None = None, + api_style: str = "responses", ): """Initialize the user simulator. @@ -184,15 +186,19 @@ def __init__( max_turns: Default maximum turns per scenario. config: Optional configuration override. criteria: Optional list of BaseCriterion to run after simulation. + api_style: OpenAI API style — ``"responses"`` (default) or + ``"chat"`` for legacy/third-party models. """ if config: self.model = config.model self.temperature = config.temperature self.max_turns = config.max_invocations + self.api_style: str = getattr(config, "api_style", "responses") else: self.model = model self.temperature = temperature self.max_turns = max_turns + self.api_style = api_style self.criteria: list[BaseCriterion] = criteria or [] @@ -501,88 +507,22 @@ async def _evaluate_simulation( return scores, details, results async def _call_llm(self, prompt: str) -> tuple[str, TokenUsage]: - """Call the LLM for user simulation. - - Uses Google GenAI as primary, OpenAI as fallback. - Returns (text, TokenUsage). - """ - from agentflow.qa.evaluation.criteria.llm_utils import _parse_model_provider - - provider, model_name = _parse_model_provider(self.model) - - if provider == "google": - text, usage = await self._call_google(model_name, prompt) - if text is not None: - return text, usage - - # OpenAI path - text, usage = await self._call_openai( - self.model if provider == "openai" else model_name, - prompt, - ) - if text is not None: - return text, usage - - # Fallback: try Google if we haven't yet - if provider != "google": - text, usage = await self._call_google(model_name, prompt) - if text is not None: - return text, usage - - return "I have a follow-up question.", TokenUsage() - - async def _call_google(self, model: str, prompt: str) -> tuple[str | None, TokenUsage]: - """Call Google GenAI for user simulation. Returns (text, TokenUsage).""" + """Call the LLM for user simulation. Returns (text, TokenUsage).""" try: - from google import genai - from google.genai import types - - client = genai.Client() - config = types.GenerateContentConfig(temperature=self.temperature) - response = await client.aio.models.generate_content( - model=model, - contents=prompt, - config=config, + text, inp, out, cache = await call_llm( + self.model, + prompt, + temperature=self.temperature, + api_style=self.api_style, # type: ignore[arg-type] ) - usage = TokenUsage() - meta = getattr(response, "usage_metadata", None) - if meta is not None: - usage = TokenUsage( - input_tokens=getattr(meta, "prompt_token_count", 0) or 0, - output_tokens=getattr(meta, "candidates_token_count", 0) or 0, - cache_read_tokens=getattr(meta, "cached_content_token_count", 0) or 0, - ) - return (response.text or "").strip(), usage - except ImportError: - return None, TokenUsage() - except Exception as e: - logger.warning("Google GenAI call failed (%s): %s", type(e).__name__, e) - return None, TokenUsage() + except Exception as exc: + logger.warning("LLM call failed (%s): %s", type(exc).__name__, exc) + return "I have a follow-up question.", TokenUsage() - async def _call_openai(self, model: str, prompt: str) -> tuple[str | None, TokenUsage]: - """Call OpenAI for user simulation. Returns (text, TokenUsage).""" - try: - from openai import AsyncOpenAI + if not text: + return "I have a follow-up question.", TokenUsage() - client = AsyncOpenAI() - response = await client.chat.completions.create( - model=model, - messages=[{"role": "user", "content": prompt}], - temperature=self.temperature, - ) - usage = TokenUsage() - raw = getattr(response, "usage", None) - if raw is not None: - usage = TokenUsage( - input_tokens=getattr(raw, "prompt_tokens", 0) or 0, - output_tokens=getattr(raw, "completion_tokens", 0) or 0, - ) - return (response.choices[0].message.content or "").strip(), usage - except ImportError: - return None, TokenUsage() - except Exception as e: - logger.warning("OpenAI call failed (%s): %s", type(e).__name__, e) - return None, TokenUsage() + return text, TokenUsage(input_tokens=inp, output_tokens=out, cache_read_tokens=cache) def _extract_response(self, result: dict[str, Any]) -> str: """Extract text response from graph result. diff --git a/agentflow/storage/store/store_schema.py b/agentflow/storage/store/store_schema.py index fd8b142..da3826b 100644 --- a/agentflow/storage/store/store_schema.py +++ b/agentflow/storage/store/store_schema.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable from contextlib import suppress from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any from injectq import InjectQ @@ -28,7 +28,7 @@ def _generate_memory_id() -> str: return secrets.token_hex(16) -class RetrievalStrategy(str, Enum): +class RetrievalStrategy(StrEnum): """Memory retrieval strategies.""" SIMILARITY = "similarity" # Vector similarity search @@ -38,7 +38,7 @@ class RetrievalStrategy(str, Enum): GRAPH_TRAVERSAL = "graph_traversal" # Knowledge graph navigation -class DistanceMetric(str, Enum): +class DistanceMetric(StrEnum): """Supported distance metrics for vector similarity.""" COSINE = "cosine" @@ -47,7 +47,7 @@ class DistanceMetric(str, Enum): MANHATTAN = "manhattan" -class MemoryType(str, Enum): +class MemoryType(StrEnum): """Types of memories that can be stored.""" EPISODIC = "episodic" # Conversation memories @@ -76,7 +76,7 @@ class MemorySearchResult(BaseModel): @classmethod def validate_vector(cls, v): if v is not None and ( - not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v) + not isinstance(v, list) or any(not isinstance(x, int | float) for x in v) ): raise ValueError("vector must be list[float] or None") return v diff --git a/tests/state/test_summary_context_manager.py b/tests/state/test_summary_context_manager.py new file mode 100644 index 0000000..0ae2015 --- /dev/null +++ b/tests/state/test_summary_context_manager.py @@ -0,0 +1,301 @@ +"""Tests for SummaryContextManager.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentflow.core.state.agent_state import AgentState +from agentflow.core.state.message import Message +from agentflow.core.state.message_block import TextBlock, ToolCallBlock, ToolResultBlock +from agentflow.core.state.summary_context_manager import ( + SummaryContextManager, + _estimate_tokens, + _messages_to_text, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_msgs(*roles_and_texts: tuple[str, str]) -> list[Message]: + return [Message.text_message(text, role=role) for role, text in roles_and_texts] + + +def _conv(n_user: int, include_system: bool = True) -> list[Message]: + msgs: list[Message] = [] + if include_system: + msgs.append(Message.text_message("System prompt", role="system")) + for i in range(n_user): + msgs.append(Message.text_message(f"User message {i + 1}", role="user")) + msgs.append(Message.text_message(f"Assistant reply {i + 1}", role="assistant")) + return msgs + + +# --------------------------------------------------------------------------- +# Unit tests — no I/O +# --------------------------------------------------------------------------- + +class TestEstimateTokens: + def test_empty(self): + assert _estimate_tokens([]) == 1 + + def test_basic(self): + msgs = _make_msgs(("user", "hello world")) # 11 chars → 2 tokens + assert _estimate_tokens(msgs) >= 1 + + def test_longer_text(self): + msgs = _make_msgs(("user", "a" * 400)) # 400 chars → 100 tokens + assert _estimate_tokens(msgs) == 100 + + def test_tool_calls_counted(self): + msg = Message( + role="assistant", + content=[ToolCallBlock(id="c1", name="search", args={"q": "x"})], + tools_calls=[{"id": "c1", "name": "search", "args": {"q": "x"}}], + ) + assert _estimate_tokens([msg]) >= 1 + + +class TestMessagesToText: + def test_basic_roles(self): + msgs = _make_msgs(("user", "hi"), ("assistant", "hello")) + text = _messages_to_text(msgs) + assert "USER: hi" in text + assert "ASSISTANT: hello" in text + + def test_empty_messages_skipped(self): + msgs = [Message(role="user", content=[])] + assert _messages_to_text(msgs) == "" + + def test_tool_calls_included(self): + msg = Message( + role="assistant", + content=[], + tools_calls=[{"name": "search"}], + ) + text = _messages_to_text([msg]) + assert "Tool calls" in text + + +class TestShouldSummarize: + def test_message_count_trigger(self): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=5, token_budget=None) + msgs = _conv(3) # 1 system + 6 msgs = 7 total + assert mgr._should_summarize(msgs) is True + + def test_message_count_no_trigger(self): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=20, token_budget=None) + msgs = _conv(3) + assert mgr._should_summarize(msgs) is False + + def test_token_budget_trigger(self): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=None, token_budget=1) + msgs = _make_msgs(("user", "hello world")) + assert mgr._should_summarize(msgs) is True + + def test_token_budget_no_trigger(self): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=None, token_budget=99999) + msgs = _conv(2) + assert mgr._should_summarize(msgs) is False + + def test_either_trigger_fires(self): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=100, token_budget=1) + msgs = _make_msgs(("user", "a" * 100)) # ~25 tokens > budget of 1 + assert mgr._should_summarize(msgs) is True + + +class TestSplitContext: + def test_keeps_system_messages(self): + msgs = _conv(5) # system + 10 msgs + mgr = SummaryContextManager("gpt-4o-mini", keep_recent=4) + to_summarize, remaining = mgr._split_context(msgs) + assert all(m.role == "system" for m in remaining if m.role == "system") + assert not any(m.role == "system" for m in to_summarize) + + def test_keep_recent_respected(self): + msgs = _conv(6, include_system=False) # 12 messages + mgr = SummaryContextManager("gpt-4o-mini", keep_recent=4) + to_summarize, remaining = mgr._split_context(msgs) + assert len(remaining) == 4 + assert len(to_summarize) == 8 + + def test_no_split_when_small(self): + msgs = _conv(2) # 5 total (system + 4) + mgr = SummaryContextManager("gpt-4o-mini", keep_recent=8) + to_summarize, remaining = mgr._split_context(msgs) + assert to_summarize == [] + assert len(remaining) == len(msgs) + + +# --------------------------------------------------------------------------- +# Async tests — mocked LLM +# --------------------------------------------------------------------------- + +_CALL_LLM = "agentflow.core.state.summary_context_manager.call_llm" + + +@pytest.mark.anyio(loop_scope="function") +async def test_atrim_context_openai(): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=5, keep_recent=2) + state = AgentState(context=_conv(4)) # 9 messages → triggers max_messages=5 + + with patch(_CALL_LLM, new=AsyncMock(return_value=("Summary text.", 10, 5, 0))) as mock_call: + result = await mgr.atrim_context(state) + + assert result.context_summary == "Summary text." + assert len(result.context) <= len(_conv(4)) + mock_call.assert_called_once() + + +@pytest.mark.anyio +async def test_atrim_context_google(): + mgr = SummaryContextManager("gemini-2.0-flash", max_messages=5, keep_recent=2) + state = AgentState(context=_conv(4)) + + with patch(_CALL_LLM, new=AsyncMock(return_value=("Google summary.", 8, 4, 0))) as mock_call: + result = await mgr.atrim_context(state) + + assert result.context_summary == "Google summary." + mock_call.assert_called_once() + + +@pytest.mark.anyio +async def test_rolling_summary_appended(): + """Subsequent summarisations append to existing context_summary.""" + mgr = SummaryContextManager("gpt-4o-mini", max_messages=5, keep_recent=2) + state = AgentState(context=_conv(4), context_summary="Previous summary.") + + with patch(_CALL_LLM, new=AsyncMock(return_value=("New chunk.", 10, 5, 0))): + result = await mgr.atrim_context(state) + + assert result.context_summary == "Previous summary.\n\nNew chunk." + + +@pytest.mark.anyio +async def test_no_summarize_when_below_threshold(): + mgr = SummaryContextManager("gpt-4o-mini", max_messages=100, token_budget=None) + state = AgentState(context=_conv(2)) + original_context = list(state.context) + + with patch(_CALL_LLM, new=AsyncMock()) as mock_call: + result = await mgr.atrim_context(state) + + mock_call.assert_not_called() + assert result.context == original_context + assert result.context_summary is None + + +@pytest.mark.anyio +async def test_llm_failure_leaves_context_unchanged(): + """If call_llm raises, context must not be modified.""" + mgr = SummaryContextManager("gpt-4o-mini", max_messages=5, keep_recent=2) + state = AgentState(context=_conv(4)) + original_len = len(state.context) + + with patch(_CALL_LLM, new=AsyncMock(side_effect=RuntimeError("API down"))): + result = await mgr.atrim_context(state) + + assert len(result.context) == original_len + assert result.context_summary is None + + +@pytest.mark.anyio +async def test_token_budget_triggers_with_long_messages(): + long_text = "x" * 4000 # ~1000 tokens + msgs = [Message.text_message(long_text, role="user")] + mgr = SummaryContextManager("gpt-4o-mini", max_messages=None, token_budget=500, keep_recent=0) + state = AgentState(context=msgs) + + with patch(_CALL_LLM, new=AsyncMock(return_value=("Compressed.", 20, 10, 0))): + result = await mgr.atrim_context(state) + + assert result.context_summary == "Compressed." + + +# --------------------------------------------------------------------------- +# remove_tool_msgs +# --------------------------------------------------------------------------- + +def _complete_tool_sequence(idx: int = 1) -> list[Message]: + """Build a complete tool sequence: user → ai-with-tool-call → tool-result → ai-final. + + remove_tool_messages only strips COMPLETE sequences, so all four messages are needed. + """ + tool_call = Message( + role="assistant", + content=[ToolCallBlock(id=f"c{idx}", name="search", args={})], + tools_calls=[{"id": f"c{idx}", "name": "search", "args": {}}], + ) + tool_result = Message.tool_message([ToolResultBlock(call_id=f"c{idx}", output="result")]) + ai_final = Message.text_message(f"Done {idx}", role="assistant") + return [tool_call, tool_result, ai_final] + + +@pytest.mark.anyio +async def test_remove_tool_msgs_strips_before_threshold(): + """Complete tool sequences are stripped; remaining messages stay below threshold.""" + # 1 complete tool sequence (3 msgs) + 2 plain messages = 5 total + # After stripping: 2 messages — below max_messages=3, so no summarisation + user1 = Message.text_message("first", role="user") + user2 = Message.text_message("second", role="user") + msgs = _complete_tool_sequence(1) + [user1, user2] + + mgr = SummaryContextManager("gpt-4o-mini", max_messages=3, remove_tool_msgs=True) + state = AgentState(context=msgs) + result = await mgr.atrim_context(state) + + # Tool messages stripped from retained context + assert all(m.role != "tool" for m in result.context) + assert not any( + isinstance(b, ToolCallBlock) + for m in result.context + for b in m.content + ) + # Threshold not exceeded after stripping — no summary + assert result.context_summary is None + + +@pytest.mark.anyio +async def test_remove_tool_msgs_then_summarizes(): + """Complete tool sequences stripped, then summarisation fires if threshold still exceeded.""" + # _conv(4) = system + 8 msgs; add 2 complete tool sequences (6 msgs) = 15 total + # After stripping 2 sequences (6 msgs): 9 messages > max_messages=5 → summarise + msgs = _conv(4) + _complete_tool_sequence(1) + _complete_tool_sequence(2) + + mgr = SummaryContextManager( + "gpt-4o-mini", max_messages=5, keep_recent=2, remove_tool_msgs=True + ) + state = AgentState(context=msgs) + + with patch(_CALL_LLM, new=AsyncMock(return_value=("Summary.", 10, 5, 0))): + result = await mgr.atrim_context(state) + + assert result.context_summary == "Summary." + assert all(m.role != "tool" for m in result.context) + + +# --------------------------------------------------------------------------- +# Provider auto-detection +# --------------------------------------------------------------------------- + +def test_provider_detected_google(): + mgr = SummaryContextManager("gemini-2.0-flash") + assert mgr._provider == "google" + + +def test_provider_detected_openai(): + mgr = SummaryContextManager("gpt-4o-mini") + assert mgr._provider == "openai" + + +def test_custom_summary_prompt(): + mgr = SummaryContextManager("gpt-4o-mini", summary_system_prompt="Be brief.") + assert mgr.summary_system_prompt == "Be brief." + + +def test_default_summary_prompt_set(): + mgr = SummaryContextManager("gpt-4o-mini") + assert len(mgr.summary_system_prompt) > 0 diff --git a/tests/utils/test_call_llm.py b/tests/utils/test_call_llm.py new file mode 100644 index 0000000..319ada3 --- /dev/null +++ b/tests/utils/test_call_llm.py @@ -0,0 +1,159 @@ +"""Tests for agentflow.core.llm.caller.call_llm.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentflow.core.llm.caller import ( + _extract_responses_text, + call_llm, +) + + +# --------------------------------------------------------------------------- +# Provider dispatch +# --------------------------------------------------------------------------- + +_DETECT = "agentflow.core.llm.caller.detect_provider" +_CREATE = "agentflow.core.llm.caller.create_llm_client" +_CALL_GOOGLE = "agentflow.core.llm.caller._call_google" +_CALL_RESP = "agentflow.core.llm.caller._call_openai_responses" +_CALL_CHAT = "agentflow.core.llm.caller._call_openai_chat" + +_DUMMY = ("text", 10, 5, 0) + + +@pytest.mark.anyio +async def test_google_model_dispatches_to_google(): + with ( + patch(_DETECT, return_value="google"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_GOOGLE, new=AsyncMock(return_value=_DUMMY)) as mock, + ): + result = await call_llm("gemini-2.0-flash", "hello") + + mock.assert_called_once() + assert result == _DUMMY + + +@pytest.mark.anyio +async def test_openai_default_dispatches_to_responses(): + with ( + patch(_DETECT, return_value="openai"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_RESP, new=AsyncMock(return_value=_DUMMY)) as mock, + ): + result = await call_llm("gpt-4o-mini", "hello") + + mock.assert_called_once() + assert result == _DUMMY + + +@pytest.mark.anyio +async def test_openai_chat_style_dispatches_to_chat(): + with ( + patch(_DETECT, return_value="openai"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_CHAT, new=AsyncMock(return_value=_DUMMY)) as mock, + ): + result = await call_llm("gpt-4o-mini", "hello", api_style="chat") + + mock.assert_called_once() + assert result == _DUMMY + + +@pytest.mark.anyio +async def test_openai_responses_style_explicit(): + """Explicitly passing api_style='responses' still hits the Responses path.""" + with ( + patch(_DETECT, return_value="openai"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_RESP, new=AsyncMock(return_value=_DUMMY)) as mock_resp, + patch(_CALL_CHAT, new=AsyncMock(return_value=_DUMMY)) as mock_chat, + ): + await call_llm("gpt-4o-mini", "hello", api_style="responses") + + mock_resp.assert_called_once() + mock_chat.assert_not_called() + + +@pytest.mark.anyio +async def test_api_style_irrelevant_for_google(): + """api_style has no effect when the provider is Google.""" + with ( + patch(_DETECT, return_value="google"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_GOOGLE, new=AsyncMock(return_value=_DUMMY)) as mock_google, + patch(_CALL_CHAT, new=AsyncMock(return_value=_DUMMY)) as mock_chat, + ): + await call_llm("gemini-2.0-flash", "hello", api_style="chat") + + mock_google.assert_called_once() + mock_chat.assert_not_called() + + +# --------------------------------------------------------------------------- +# Parameters forwarded correctly +# --------------------------------------------------------------------------- + +@pytest.mark.anyio +async def test_system_prompt_forwarded_to_responses(): + with ( + patch(_DETECT, return_value="openai"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_RESP, new=AsyncMock(return_value=_DUMMY)) as mock, + ): + await call_llm("gpt-4o-mini", "hi", system_prompt="Be brief.", json_mode=True) + + _, kwargs = mock.call_args + assert kwargs["system_prompt"] == "Be brief." + assert kwargs["json_mode"] is True + + +@pytest.mark.anyio +async def test_system_prompt_forwarded_to_chat(): + with ( + patch(_DETECT, return_value="openai"), + patch(_CREATE, return_value=MagicMock()), + patch(_CALL_CHAT, new=AsyncMock(return_value=_DUMMY)) as mock, + ): + await call_llm("gpt-4o-mini", "hi", system_prompt="Be brief.", api_style="chat") + + _, kwargs = mock.call_args + assert kwargs["system_prompt"] == "Be brief." + + +# --------------------------------------------------------------------------- +# _extract_responses_text +# --------------------------------------------------------------------------- + +def _make_response(output_text=None, output=None): + r = MagicMock() + r.output_text = output_text + r.output = output or [] + return r + + +def test_extract_uses_output_text_property(): + r = _make_response(output_text=" hello ") + assert _extract_responses_text(r) == "hello" + + +def test_extract_falls_back_to_output_items(): + part = MagicMock() + part.type = "output_text" + part.text = "fallback" + + item = MagicMock() + item.type = "message" + item.content = [part] + + r = _make_response(output_text=None, output=[item]) + assert _extract_responses_text(r) == "fallback" + + +def test_extract_returns_empty_when_no_text(): + r = _make_response(output_text=None, output=[]) + assert _extract_responses_text(r) == ""