diff --git a/.env.public b/.env.public index 2bc835cde..306c1ef93 100644 --- a/.env.public +++ b/.env.public @@ -13,3 +13,23 @@ WATSONX_URL=https://us-south.ml.cloud.ibm.com # optional # ── LiteLLM (plan-execute runner) ──────────────────────────────────────────── LITELLM_API_KEY= LITELLM_BASE_URL= + +# ── LLM generation parameters (all runners) ────────────────────────────────── +# Applied to every LLM call across plan-execute, claude-agent, openai-agent, +# deep-agent, and the FMSR server. All variables are optional; unset variables +# fall back to safe defaults shown in comments. +LLM_MAX_TOKENS= # int — max output tokens (default: 4096) +LLM_TEMPERATURE= # float — sampling temperature (default: 0.0) +LLM_TOP_P= # float — nucleus sampling top-p (default: omit) +LLM_REASONING_EFFORT= # none|low|medium|high|max (default: none) + # Controls extended thinking / reasoning depth. + # Mapped per provider: Claude uses effort+thinking, + # OpenAI uses Reasoning.effort (max→xhigh), + # LiteLLM passes reasoning_effort directly. + # Stripped with a warning on unsupported models + # (e.g. WatsonX Llama). +LLM_THINKING_BUDGET_TOKENS= # int — explicit thinking budget (default: omit) + # Only used for legacy Anthropic budget-style + # thinking (pre-4.6 models). Ignored on 4.6+. +LLM_STOP= # comma-separated stop sequences (default: omit) + # Example: LLM_STOP=",END" diff --git a/src/agent/claude_agent/runner.py b/src/agent/claude_agent/runner.py index 9a11b3e84..37147c0ed 100644 --- a/src/agent/claude_agent/runner.py +++ b/src/agent/claude_agent/runner.py @@ -26,8 +26,10 @@ from observability import agent_run_span, persist_trajectory +from llm.generation import GenerationParams from .._litellm import LITELLM_PREFIX, resolve_model from .._prompts import AGENT_SYSTEM_PROMPT +from ..generation_maps import to_claude_agent_options from ..models import AgentResult, ToolCall, Trajectory, TurnRecord from ..runner import AgentRunner @@ -95,9 +97,12 @@ def __init__( model: str = _DEFAULT_MODEL, max_turns: int = 30, permission_mode: str = "bypassPermissions", + *, + generation: GenerationParams | None = None, ) -> None: - super().__init__(llm, server_paths) + super().__init__(llm, server_paths, generation=generation) self._model = resolve_model(model) + self._model_id = model self._sdk_env = _sdk_env(model) self._max_turns = max_turns self._permission_mode = permission_mode @@ -123,6 +128,7 @@ async def run(self, question: str) -> AgentResult: permission_mode=self._permission_mode, env=self._sdk_env, ) + to_claude_agent_options(options, self._generation, self._model_id) _log.info("ClaudeAgentRunner: starting query (model=%s)", self._model) answer = "" diff --git a/src/agent/deep_agent/runner.py b/src/agent/deep_agent/runner.py index 3d975d55c..a31a5c5fe 100644 --- a/src/agent/deep_agent/runner.py +++ b/src/agent/deep_agent/runner.py @@ -27,8 +27,10 @@ from observability import agent_run_span, persist_trajectory +from llm.generation import GenerationParams from .._litellm import LITELLM_PREFIX, resolve_model from .._prompts import AGENT_SYSTEM_PROMPT +from ..generation_maps import to_chat_openai_kwargs from ..models import AgentResult, ToolCall, Trajectory, TurnRecord from ..runner import AgentRunner @@ -39,7 +41,7 @@ _DEFAULT_MODEL = "litellm_proxy/aws/claude-opus-4-6" -def _build_chat_model(model_id: str): +def _build_chat_model(model_id: str, extra_kwargs: dict | None = None): """Construct a LangChain chat model for *model_id*. When the ID uses the ``litellm_proxy/`` prefix, a :class:`ChatOpenAI` @@ -47,6 +49,7 @@ def _build_chat_model(model_id: str): ``LITELLM_API_KEY``). Otherwise the model string is passed to ``init_chat_model`` so any provider supported by LangChain can be used. """ + extra_kwargs = extra_kwargs or {} if model_id.startswith(LITELLM_PREFIX): base_url = os.environ.get("LITELLM_BASE_URL") api_key = os.environ.get("LITELLM_API_KEY") @@ -61,11 +64,12 @@ def _build_chat_model(model_id: str): model=resolve_model(model_id), base_url=base_url, api_key=api_key, + **extra_kwargs, ) from langchain.chat_models import init_chat_model - return init_chat_model(model_id) + return init_chat_model(model_id, **extra_kwargs) def _build_mcp_connections( @@ -168,15 +172,20 @@ def __init__( server_paths: dict[str, Path | str] | None = None, model: str = _DEFAULT_MODEL, recursion_limit: int = 100, + *, + generation: GenerationParams | None = None, ) -> None: - super().__init__(llm, server_paths) + super().__init__(llm, server_paths, generation=generation) self._model_id = model self._recursion_limit = recursion_limit @cached_property def _chat_model(self): """LangChain chat model, built once per runner instance.""" - return _build_chat_model(self._model_id) + return _build_chat_model( + self._model_id, + to_chat_openai_kwargs(self._generation, self._model_id), + ) async def run(self, question: str) -> AgentResult: """Run the deep-agents loop for *question*. diff --git a/src/agent/generation_maps.py b/src/agent/generation_maps.py new file mode 100644 index 000000000..2c38c8a47 --- /dev/null +++ b/src/agent/generation_maps.py @@ -0,0 +1,185 @@ +"""Maps :class:`~llm.GenerationParams` to each agent SDK's native config types. + +Import chain is intentionally one-way: + llm.generation ← (no agent deps) + agent.generation_maps → imports from SDK packages only when called + +Three public helpers, one per SDK: + + to_claude_agent_options(options, params, model_id) + Applies reasoning / thinking + strips-and-warns other params. + + to_model_settings(params, model_id) -> ModelSettings + Returns an openai-agents ModelSettings. + + to_chat_openai_kwargs(params, model_id) -> dict + Returns init / bind kwargs for langchain_openai.ChatOpenAI. +""" + +from __future__ import annotations + +import logging + +from llm.generation import ( + GenerationParams, + EFFORT_TO_OPENAI, + reasoning_supported, +) + +_log = logging.getLogger(__name__) + + +# ── Claude Agent SDK ────────────────────────────────────────────────────────── + + +def to_claude_agent_options( + options, # claude_agent_sdk.ClaudeAgentOptions (avoid hard import at module level) + params: GenerationParams, + model_id: str, +) -> None: + """Mutate *options* in-place with generation params. + + Claude Agent SDK fields handled natively: + - ``effort`` → ``options.effort`` + - ``thinking`` → ``options.thinking`` + + All other params (max_tokens, temperature, top_p, stop) are forwarded via + ``options.extra_args`` if non-default, with a warning that support depends + on the underlying CLI version. + """ + effort = params.reasoning_effort + + if reasoning_supported(model_id): + if effort == "none": + from claude_agent_sdk.types import ThinkingConfigDisabled + + options.thinking = ThinkingConfigDisabled(type="disabled") + options.effort = None + elif params.thinking_budget_tokens is not None: + from claude_agent_sdk.types import ThinkingConfigEnabled + + options.thinking = ThinkingConfigEnabled( + type="enabled", + budget_tokens=params.thinking_budget_tokens, + ) + options.effort = None + else: + from claude_agent_sdk.types import ThinkingConfigAdaptive + + options.thinking = ThinkingConfigAdaptive(type="adaptive") + options.effort = effort # type: ignore[assignment] + elif effort != "none": + _log.warning( + "reasoning_effort=%r requested but model %r does not support " + "reasoning on claude-agent — stripping thinking kwargs.", + effort, + model_id, + ) + + extra: dict[str, str | None] = dict(options.extra_args or {}) + + # max_tokens via extra_args (CLI flag name; strip+warn if unsupported at runtime) + if params.max_tokens != GenerationParams.max_tokens: + _log.warning( + "claude-agent: max_tokens=%d forwarded via extra_args; " + "support depends on the installed Claude Code CLI version.", + params.max_tokens, + ) + extra["max-tokens"] = str(params.max_tokens) + + if params.temperature != GenerationParams.temperature: + _log.warning( + "claude-agent: temperature=%.3g — ClaudeAgentOptions has no " + "native temperature field; stripping.", + params.temperature, + ) + + if params.top_p is not None: + _log.warning( + "claude-agent: top_p=%.3g — ClaudeAgentOptions has no native " + "top_p field; stripping.", + params.top_p, + ) + + if params.stop: + _log.warning( + "claude-agent: stop sequences — ClaudeAgentOptions has no native " + "stop field; stripping.", + ) + + options.extra_args = extra + + +# ── OpenAI Agents SDK ───────────────────────────────────────────────────────── + + +def to_model_settings(params: GenerationParams, model_id: str): + """Return an ``agents.ModelSettings`` populated from *params*. + + ``reasoning_effort`` is mapped to ``ModelSettings.reasoning`` with the + OpenAI-compatible vocab (``max`` → ``xhigh``). Unsupported models get the + reasoning field stripped with a warning. + """ + from agents import ModelSettings + from openai.types.shared import Reasoning + + kwargs: dict = { + "max_tokens": params.max_tokens, + "temperature": params.temperature, + } + + if params.top_p is not None: + kwargs["top_p"] = params.top_p + + if params.stop: + kwargs["extra_args"] = {"stop": list(params.stop)} + + effort = params.reasoning_effort + if effort != "none": + if reasoning_supported(model_id): + openai_effort = EFFORT_TO_OPENAI[effort] + kwargs["reasoning"] = Reasoning(effort=openai_effort) # type: ignore[arg-type] + else: + _log.warning( + "reasoning_effort=%r requested but model %r does not support " + "reasoning on openai-agent — stripping.", + effort, + model_id, + ) + + return ModelSettings(**kwargs) + + +# ── LangChain ChatOpenAI (deep-agent) ───────────────────────────────────────── + + +def to_chat_openai_kwargs(params: GenerationParams, model_id: str) -> dict: + """Return init / ``.bind()`` kwargs for ``langchain_openai.ChatOpenAI``. + + Passes generation params through ``model_kwargs`` so the LiteLLM proxy + (which presents an OpenAI-compatible interface) forwards them correctly. + """ + kwargs: dict = { + "max_tokens": params.max_tokens, + "temperature": params.temperature, + } + + if params.top_p is not None: + kwargs["top_p"] = params.top_p + + if params.stop: + kwargs["stop"] = list(params.stop) + + effort = params.reasoning_effort + if effort != "none": + if reasoning_supported(model_id): + kwargs["reasoning_effort"] = EFFORT_TO_OPENAI[effort] + else: + _log.warning( + "reasoning_effort=%r requested but model %r does not support " + "reasoning on deep-agent — stripping.", + effort, + model_id, + ) + + return kwargs diff --git a/src/agent/openai_agent/runner.py b/src/agent/openai_agent/runner.py index 8dfccb48d..95abe2382 100644 --- a/src/agent/openai_agent/runner.py +++ b/src/agent/openai_agent/runner.py @@ -30,8 +30,10 @@ from observability import agent_run_span, persist_trajectory +from llm.generation import GenerationParams from .._litellm import LITELLM_PREFIX, resolve_model from .._prompts import AGENT_SYSTEM_PROMPT +from ..generation_maps import to_model_settings from ..models import AgentResult, ToolCall, Trajectory, TurnRecord from ..runner import AgentRunner @@ -193,8 +195,10 @@ def __init__( server_paths: dict[str, Path | str] | None = None, model: str = _DEFAULT_MODEL, max_turns: int = 30, + *, + generation: GenerationParams | None = None, ) -> None: - super().__init__(llm, server_paths) + super().__init__(llm, server_paths, generation=generation) self._model_id = model self._model = resolve_model(model) self._run_config = _build_run_config(model) @@ -227,6 +231,7 @@ async def run(self, question: str) -> AgentResult: instructions=AGENT_SYSTEM_PROMPT, mcp_servers=active_servers, model=self._model, + model_settings=to_model_settings(self._generation, self._model_id), ) _log.info( diff --git a/src/agent/plan_execute/runner.py b/src/agent/plan_execute/runner.py index ea684a650..d6c242c88 100644 --- a/src/agent/plan_execute/runner.py +++ b/src/agent/plan_execute/runner.py @@ -17,6 +17,7 @@ from pathlib import Path from llm import LLMBackend, LLMResult +from llm.generation import GenerationParams from observability import agent_run_span, persist_trajectory from .executor import Executor @@ -44,16 +45,26 @@ def reset(self) -> None: self.input_tokens = 0 self.output_tokens = 0 - def generate(self, prompt: str, temperature: float = 0.0) -> str: - result = self._inner.generate_with_usage(prompt, temperature) + def generate( + self, + prompt: str, + temperature: float = 0.0, + *, + params: GenerationParams | None = None, + ) -> str: + result = self._inner.generate_with_usage(prompt, temperature, params=params) self.input_tokens += result.input_tokens self.output_tokens += result.output_tokens return result.text def generate_with_usage( - self, prompt: str, temperature: float = 0.0 + self, + prompt: str, + temperature: float = 0.0, + *, + params: GenerationParams | None = None, ) -> LLMResult: - result = self._inner.generate_with_usage(prompt, temperature) + result = self._inner.generate_with_usage(prompt, temperature, params=params) self.input_tokens += result.input_tokens self.output_tokens += result.output_tokens return result @@ -62,6 +73,7 @@ def generate_with_usage( def model_id(self) -> str: return self._inner.model_id + _log = logging.getLogger(__name__) _SUMMARIZE_PROMPT = """\ @@ -102,8 +114,10 @@ def __init__( self, llm: LLMBackend, server_paths: dict[str, Path | str] | None = None, + *, + generation: GenerationParams | None = None, ) -> None: - super().__init__(llm, server_paths) + super().__init__(llm, server_paths, generation=generation) self._meter = _TokenMeter(llm) self._planner = Planner(self._meter) self._executor = Executor(self._meter, server_paths) diff --git a/src/agent/runner.py b/src/agent/runner.py index 1c06ec601..844f1206f 100644 --- a/src/agent/runner.py +++ b/src/agent/runner.py @@ -6,6 +6,7 @@ from pathlib import Path from llm import LLMBackend +from llm.generation import GenerationParams, from_env from .models import AgentResult @@ -30,17 +31,30 @@ class AgentRunner(ABC): return an :class:`AgentResult`. After ``super().__init__``, ``self._server_paths`` is always a concrete ``dict`` — either the caller's override, or a copy of :data:`DEFAULT_SERVER_PATHS`. + + Args: + llm: LLM backend (used by plan-execute; SDK-based runners accept + ``None`` for interface compatibility). + server_paths: MCP server specs. Defaults to :data:`DEFAULT_SERVER_PATHS`. + generation: Generation parameters applied to all LLM calls made by this + runner. Defaults to :func:`~llm.generation.from_env` so + env vars take effect without explicit construction. """ def __init__( self, llm: LLMBackend, server_paths: dict[str, Path | str] | None = None, + *, + generation: GenerationParams | None = None, ) -> None: self._llm = llm self._server_paths: dict[str, Path | str] = ( dict(DEFAULT_SERVER_PATHS) if server_paths is None else server_paths ) + self._generation: GenerationParams = ( + generation if generation is not None else from_env() + ) @abstractmethod async def run(self, question: str) -> AgentResult: diff --git a/src/agent/tests/conftest.py b/src/agent/tests/conftest.py index b3827ce6a..87a4867b8 100644 --- a/src/agent/tests/conftest.py +++ b/src/agent/tests/conftest.py @@ -11,7 +11,7 @@ class MockLLM(LLMBackend): def __init__(self, response: str = "") -> None: self._response = response - def generate(self, prompt: str, temperature: float = 0.0) -> str: + def generate(self, prompt: str, temperature: float = 0.0, **_kw) -> str: return self._response @@ -21,7 +21,7 @@ class SequentialMockLLM(LLMBackend): def __init__(self, responses: list[str]) -> None: self._responses = iter(responses) - def generate(self, prompt: str, temperature: float = 0.0) -> str: + def generate(self, prompt: str, temperature: float = 0.0, **_kw) -> str: return next(self._responses, "") diff --git a/src/agent/tests/test_runner.py b/src/agent/tests/test_runner.py index 6bdbac98b..0784f4978 100644 --- a/src/agent/tests/test_runner.py +++ b/src/agent/tests/test_runner.py @@ -142,11 +142,11 @@ class _UsageReportingLLM(LLMBackend): def __init__(self, items: list[tuple[str, int, int]]) -> None: self._items = iter(items) - def generate(self, prompt: str, temperature: float = 0.0) -> str: + def generate(self, prompt: str, temperature: float = 0.0, **_kw) -> str: return self.generate_with_usage(prompt, temperature).text def generate_with_usage( - self, prompt: str, temperature: float = 0.0 + self, prompt: str, temperature: float = 0.0, **_kw ) -> LLMResult: text, in_tok, out_tok = next(self._items, ("", 0, 0)) return LLMResult(text=text, input_tokens=in_tok, output_tokens=out_tok) diff --git a/src/evaluation/tests/test_scorers.py b/src/evaluation/tests/test_scorers.py index 8f2ac6b69..ea132b8c3 100644 --- a/src/evaluation/tests/test_scorers.py +++ b/src/evaluation/tests/test_scorers.py @@ -16,7 +16,7 @@ class _StubLLM(LLMBackend): def __init__(self, response: str) -> None: self._response = response - def generate(self, prompt: str, temperature: float = 0.0) -> str: + def generate(self, prompt: str, temperature: float = 0.0, **_kw) -> str: return self._response diff --git a/src/llm/__init__.py b/src/llm/__init__.py index ff6e606ed..db06f618a 100644 --- a/src/llm/__init__.py +++ b/src/llm/__init__.py @@ -1,6 +1,15 @@ """LLM backend for AssetOpsBench MCP.""" from .base import LLMBackend, LLMResult +from .generation import GenerationParams, from_env, resolve_params, reasoning_supported from .litellm import LiteLLMBackend -__all__ = ["LLMBackend", "LLMResult", "LiteLLMBackend"] +__all__ = [ + "LLMBackend", + "LLMResult", + "LiteLLMBackend", + "GenerationParams", + "from_env", + "resolve_params", + "reasoning_supported", +] diff --git a/src/llm/base.py b/src/llm/base.py index a6b085141..adedeb9bb 100644 --- a/src/llm/base.py +++ b/src/llm/base.py @@ -4,6 +4,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .generation import GenerationParams @dataclass(frozen=True) @@ -23,19 +27,38 @@ class LLMBackend(ABC): """Abstract interface for LLM backends.""" @abstractmethod - def generate(self, prompt: str, temperature: float = 0.0) -> str: - """Generate text given a prompt.""" + def generate( + self, + prompt: str, + temperature: float = 0.0, + *, + params: GenerationParams | None = None, + ) -> str: + """Generate text given a prompt. + + Args: + prompt: The prompt to send to the model. + temperature: Sampling temperature — **always** overrides any value + stored in *params* when provided (even ``0.0``). + params: Optional :class:`~llm.GenerationParams` override applied on + top of the backend's stored defaults. ``None`` means use + backend defaults unchanged. + """ ... def generate_with_usage( - self, prompt: str, temperature: float = 0.0 + self, + prompt: str, + temperature: float = 0.0, + *, + params: GenerationParams | None = None, ) -> LLMResult: """Generate text and report token usage. Default impl delegates to :meth:`generate` and reports zero usage — backends that can surface counts (e.g. LiteLLM) should override. """ - return LLMResult(text=self.generate(prompt, temperature)) + return LLMResult(text=self.generate(prompt, temperature, params=params)) @property def model_id(self) -> str: diff --git a/src/llm/generation.py b/src/llm/generation.py new file mode 100644 index 000000000..60c6742a5 --- /dev/null +++ b/src/llm/generation.py @@ -0,0 +1,162 @@ +"""Generation parameters shared across all LLM backends and agent runners. + +Single source of truth for configurable generation knobs. Every backend and +runner reads from :class:`GenerationParams`, which is populated from env vars +via :func:`from_env` and merged with per-constructor / per-call overrides via +:func:`resolve_params`. + +Environment variables (all optional): + + LLM_MAX_TOKENS int (default 4096) + LLM_TEMPERATURE float (default 0.0) + LLM_TOP_P float (default: omit) + LLM_REASONING_EFFORT none|low|medium|high|max (default: none) + LLM_THINKING_BUDGET_TOKENS int (default: omit) + LLM_STOP comma-separated stop sequences (default: omit) +""" + +from __future__ import annotations + +import dataclasses +import logging +import os +from dataclasses import dataclass, replace +from typing import Literal + +_log = logging.getLogger(__name__) + +ReasoningEffort = Literal["none", "low", "medium", "high", "max"] + +# Maps our canonical effort labels to the OpenAI Agents SDK's Reasoning.effort vocab. +EFFORT_TO_OPENAI: dict[str, str] = { + "none": "none", + "low": "low", + "medium": "medium", + "high": "high", + "max": "xhigh", +} + +# Model-id substrings that indicate a reasoning-capable model (checked on lowercased id). +_REASONING_SUBSTRINGS = ( + "claude", + "anthropic", + "o1", + "o3", + "gpt-5", + "gpt-o", +) + +# Prefixes whose models are never reasoning-capable regardless of the rest of the id. +_NO_REASONING_PREFIXES = ("watsonx/",) + + +def reasoning_supported(model_id: str) -> bool: + """Return ``True`` when *model_id* is known to support reasoning / thinking. + + Heuristic: checks the full model string (including ``litellm_proxy/...`` + tails) against known reasoning-capable families. WatsonX is always False. + Unknown models return False (safe default — unknown effort will be stripped + with a warning at the mapper level). + """ + lower = model_id.lower() + for prefix in _NO_REASONING_PREFIXES: + if lower.startswith(prefix): + return False + for substr in _REASONING_SUBSTRINGS: + if substr in lower: + return True + return False + + +@dataclass(frozen=True) +class GenerationParams: + """Immutable generation configuration shared across all backends/runners. + + Optional fields (``top_p``, ``stop``, ``thinking_budget_tokens``) default + to ``None``, which means **omit from API calls** — they are not sent to the + provider unless explicitly set. + + Merge instances with :func:`resolve_params`; the explicit ``temperature`` + argument on :meth:`~llm.LLMBackend.generate` **always** overrides this. + """ + + max_tokens: int = 4096 + temperature: float = 0.0 + reasoning_effort: ReasoningEffort = "none" + top_p: float | None = None + thinking_budget_tokens: int | None = None + stop: tuple[str, ...] | None = None + + +def from_env() -> GenerationParams: + """Build a :class:`GenerationParams` from environment variables. + + All variables are optional; missing or empty ones fall back to the + :class:`GenerationParams` dataclass defaults. + """ + kwargs: dict = {} + + if raw := os.environ.get("LLM_MAX_TOKENS"): + kwargs["max_tokens"] = int(raw) + + if raw := os.environ.get("LLM_TEMPERATURE"): + kwargs["temperature"] = float(raw) + + if raw := os.environ.get("LLM_TOP_P"): + kwargs["top_p"] = float(raw) + + if raw := os.environ.get("LLM_REASONING_EFFORT"): + effort = raw.strip().lower() + if effort not in ("none", "low", "medium", "high", "max"): + _log.warning( + "LLM_REASONING_EFFORT=%r is not a valid value " + "(none|low|medium|high|max); using 'none'.", + raw, + ) + effort = "none" + kwargs["reasoning_effort"] = effort + + if raw := os.environ.get("LLM_THINKING_BUDGET_TOKENS"): + kwargs["thinking_budget_tokens"] = int(raw) + + if raw := os.environ.get("LLM_STOP"): + parts = tuple(s.strip() for s in raw.split(",") if s.strip()) + if parts: + kwargs["stop"] = parts + + return GenerationParams(**kwargs) + + +def resolve_params( + base: GenerationParams, + *, + override: GenerationParams | None = None, + temperature: float | None = None, +) -> GenerationParams: + """Merge *override* onto *base*, then optionally pin *temperature*. + + Merge semantics: + - ``None`` fields in *override* **do not** replace values in *base* + (None means "not set / omit"). + - Non-``None`` override fields always win over *base*. + - *temperature*, when provided as a keyword argument, **always** wins + over any merged value — matches the ``generate(prompt, temperature=…)`` + contract. + + Returns a new frozen :class:`GenerationParams`. + """ + merged = base + + if override is not None: + changes = { + f.name: getattr(override, f.name) + for f in dataclasses.fields(override) + if getattr(override, f.name) is not None + } + if changes: + merged = replace(merged, **changes) + + if temperature is not None: + merged = replace(merged, temperature=temperature) + + return merged diff --git a/src/llm/litellm.py b/src/llm/litellm.py index 85067c7c1..2e19dbb38 100644 --- a/src/llm/litellm.py +++ b/src/llm/litellm.py @@ -14,13 +14,72 @@ from __future__ import annotations +import logging import os from .base import LLMBackend, LLMResult - +from .generation import ( + GenerationParams, + from_env, + reasoning_supported, + resolve_params, +) + +_log = logging.getLogger(__name__) _WATSONX_PREFIX = "watsonx/" +def to_litellm_kwargs(model_id: str, params: GenerationParams) -> dict: + """Build the extra kwargs dict to pass to ``litellm.completion``. + + Maps :class:`GenerationParams` → litellm-shaped parameters, gating + reasoning/thinking fields by :func:`~.generation.reasoning_supported`. + Strips and warns when reasoning is requested but the model doesn't support it. + """ + kwargs: dict = { + "max_tokens": params.max_tokens, + "temperature": params.temperature, + } + + if params.top_p is not None: + kwargs["top_p"] = params.top_p + + if params.stop: + kwargs["stop"] = list(params.stop) + + effort = params.reasoning_effort + if effort != "none": + if reasoning_supported(model_id): + kwargs["reasoning_effort"] = effort + else: + _log.warning( + "reasoning_effort=%r requested but model %r does not support " + "reasoning — stripping thinking kwargs.", + effort, + model_id, + ) + elif reasoning_supported(model_id): + # Explicitly disable thinking on supported models so behaviour is + # deterministic rather than provider-default. + kwargs["thinking"] = {"type": "disabled"} + + if params.thinking_budget_tokens is not None: + if reasoning_supported(model_id) and effort != "none": + kwargs["thinking"] = { + "type": "enabled", + "budget_tokens": params.thinking_budget_tokens, + } + elif params.thinking_budget_tokens is not None: + _log.warning( + "thinking_budget_tokens set but reasoning_effort=%r or model " + "%r does not support thinking — ignoring budget.", + effort, + model_id, + ) + + return kwargs + + class LiteLLMBackend(LLMBackend): """LLM backend using the litellm library. @@ -28,24 +87,45 @@ class LiteLLMBackend(LLMBackend): model_id: litellm model string with provider prefix, e.g.: ``"watsonx/meta-llama/llama-3-3-70b-instruct"`` ``"litellm_proxy/GCP/claude-4-sonnet"`` + params: Generation parameters. Defaults to :func:`~.generation.from_env` + when not provided. """ - def __init__(self, model_id: str) -> None: + def __init__( + self, + model_id: str, + params: GenerationParams | None = None, + ) -> None: self._model_id = model_id + self._params: GenerationParams = params if params is not None else from_env() - def generate(self, prompt: str, temperature: float = 0.0) -> str: - return self.generate_with_usage(prompt, temperature).text + def generate( + self, + prompt: str, + temperature: float = 0.0, + *, + params: GenerationParams | None = None, + ) -> str: + return self.generate_with_usage(prompt, temperature, params=params).text def generate_with_usage( - self, prompt: str, temperature: float = 0.0 + self, + prompt: str, + temperature: float = 0.0, + *, + params: GenerationParams | None = None, ) -> LLMResult: import litellm + effective = resolve_params( + self._params, override=params, temperature=temperature + ) + extra = to_litellm_kwargs(self._model_id, effective) + kwargs: dict = { "model": self._model_id, "messages": [{"role": "user", "content": prompt}], - "temperature": temperature, - "max_tokens": 2048, + **extra, } if self._model_id.startswith(_WATSONX_PREFIX): diff --git a/src/llm/tests/__init__.py b/src/llm/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llm/tests/test_generation.py b/src/llm/tests/test_generation.py new file mode 100644 index 000000000..eff24e244 --- /dev/null +++ b/src/llm/tests/test_generation.py @@ -0,0 +1,164 @@ +"""Unit tests for GenerationParams, from_env, resolve_params, reasoning_supported.""" + +from __future__ import annotations + +import pytest + +from llm.generation import ( + GenerationParams, + EFFORT_TO_OPENAI, + from_env, + reasoning_supported, + resolve_params, +) + + +# ── reasoning_supported heuristics ─────────────────────────────────────────── + + +@pytest.mark.parametrize( + "model_id,expected", + [ + ("watsonx/meta-llama/llama-3-3-70b-instruct", False), + ("watsonx/ibm/granite-3-3-8b-instruct", False), + ("litellm_proxy/aws/claude-opus-4-6", True), + ("litellm_proxy/GCP/claude-4-sonnet", True), + ("litellm_proxy/azure/gpt-5.4", True), + ("anthropic/claude-sonnet-4-5", True), + ("openai/o1-preview", True), + ("openai/o3-mini", True), + ("openai/gpt-5", True), + ("some-unknown-model", False), + ("ollama/llama3", False), + ], +) +def test_reasoning_supported(model_id, expected): + assert reasoning_supported(model_id) is expected + + +# ── effort vocab mapping ────────────────────────────────────────────────────── + + +def test_effort_to_openai_map_complete(): + for effort in ("none", "low", "medium", "high", "max"): + assert effort in EFFORT_TO_OPENAI + + +def test_max_maps_to_xhigh(): + assert EFFORT_TO_OPENAI["max"] == "xhigh" + + +def test_none_maps_to_none(): + assert EFFORT_TO_OPENAI["none"] == "none" + + +# ── from_env ───────────────────────────────────────────────────────────────── + + +def test_from_env_defaults(monkeypatch): + for var in ( + "LLM_MAX_TOKENS", + "LLM_TEMPERATURE", + "LLM_TOP_P", + "LLM_REASONING_EFFORT", + "LLM_THINKING_BUDGET_TOKENS", + "LLM_STOP", + ): + monkeypatch.delenv(var, raising=False) + p = from_env() + assert p.max_tokens == 4096 + assert p.temperature == 0.0 + assert p.reasoning_effort == "none" + assert p.top_p is None + assert p.thinking_budget_tokens is None + assert p.stop is None + + +def test_from_env_max_tokens(monkeypatch): + monkeypatch.setenv("LLM_MAX_TOKENS", "8192") + assert from_env().max_tokens == 8192 + + +def test_from_env_temperature(monkeypatch): + monkeypatch.setenv("LLM_TEMPERATURE", "0.7") + assert from_env().temperature == pytest.approx(0.7) + + +def test_from_env_top_p(monkeypatch): + monkeypatch.setenv("LLM_TOP_P", "0.9") + assert from_env().top_p == pytest.approx(0.9) + + +def test_from_env_reasoning_effort_valid(monkeypatch): + for effort in ("none", "low", "medium", "high", "max"): + monkeypatch.setenv("LLM_REASONING_EFFORT", effort) + assert from_env().reasoning_effort == effort + + +def test_from_env_reasoning_effort_invalid_falls_back(monkeypatch, caplog): + monkeypatch.setenv("LLM_REASONING_EFFORT", "ultra") + import logging + + with caplog.at_level(logging.WARNING, logger="llm.generation"): + p = from_env() + assert p.reasoning_effort == "none" + assert "LLM_REASONING_EFFORT" in caplog.text + + +def test_from_env_thinking_budget_tokens(monkeypatch): + monkeypatch.setenv("LLM_THINKING_BUDGET_TOKENS", "2048") + assert from_env().thinking_budget_tokens == 2048 + + +def test_from_env_stop_comma_separated(monkeypatch): + monkeypatch.setenv("LLM_STOP", ",END, STOP ") + p = from_env() + assert p.stop == ("", "END", "STOP") + + +def test_from_env_stop_empty_skipped(monkeypatch): + monkeypatch.setenv("LLM_STOP", " , ") + assert from_env().stop is None + + +# ── resolve_params ──────────────────────────────────────────────────────────── + + +def test_resolve_params_no_override(): + base = GenerationParams(max_tokens=512, temperature=0.5) + result = resolve_params(base) + assert result is base + + +def test_resolve_params_temperature_always_wins(): + base = GenerationParams(temperature=0.5) + result = resolve_params(base, temperature=0.9) + assert result.temperature == pytest.approx(0.9) + + +def test_resolve_params_temperature_zero_wins(): + base = GenerationParams(temperature=0.8) + result = resolve_params(base, temperature=0.0) + assert result.temperature == pytest.approx(0.0) + + +def test_resolve_params_override_non_none_replaces(): + base = GenerationParams(max_tokens=1024) + override = GenerationParams(max_tokens=8192) + result = resolve_params(base, override=override) + assert result.max_tokens == 8192 + + +def test_resolve_params_none_field_does_not_replace(): + base = GenerationParams(top_p=0.95) + override = GenerationParams() # top_p=None by default + result = resolve_params(base, override=override) + assert result.top_p == pytest.approx(0.95) + + +def test_resolve_params_frozen(): + base = GenerationParams() + result = resolve_params(base, temperature=0.5) + assert result is not base + with pytest.raises((AttributeError, TypeError)): + result.temperature = 0.1 # type: ignore[misc] diff --git a/src/llm/tests/test_litellm_kwargs.py b/src/llm/tests/test_litellm_kwargs.py new file mode 100644 index 000000000..a952405d8 --- /dev/null +++ b/src/llm/tests/test_litellm_kwargs.py @@ -0,0 +1,147 @@ +"""Unit tests for to_litellm_kwargs — kwargs emitted per model × effort matrix.""" + +from __future__ import annotations + +import logging + +import pytest + +from llm.generation import GenerationParams +from llm.litellm import to_litellm_kwargs + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +_CLAUDE_MODEL = "litellm_proxy/aws/claude-opus-4-6" +_WATSONX_MODEL = "watsonx/meta-llama/llama-3-3-70b-instruct" +_GPT_MODEL = "litellm_proxy/azure/gpt-5.4" +_UNKNOWN_MODEL = "some-custom-model" + + +# ── basic fields always present ─────────────────────────────────────────────── + + +def test_max_tokens_present(): + p = GenerationParams(max_tokens=8192) + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["max_tokens"] == 8192 + + +def test_temperature_present(): + p = GenerationParams(temperature=0.7) + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["temperature"] == pytest.approx(0.7) + + +def test_top_p_omitted_when_none(): + p = GenerationParams() + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert "top_p" not in kw + + +def test_top_p_included_when_set(): + p = GenerationParams(top_p=0.9) + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["top_p"] == pytest.approx(0.9) + + +def test_stop_omitted_when_none(): + p = GenerationParams() + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert "stop" not in kw + + +def test_stop_included_when_set(): + p = GenerationParams(stop=("END", "")) + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["stop"] == ["END", ""] + + +# ── thinking: supported model ───────────────────────────────────────────────── + + +def test_thinking_disabled_explicit_when_effort_none_on_claude(): + """effort=none on a Claude model → emit thinking=disabled to be deterministic.""" + p = GenerationParams(reasoning_effort="none") + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw.get("thinking") == {"type": "disabled"} + assert "reasoning_effort" not in kw + + +def test_reasoning_effort_forwarded_on_claude(): + p = GenerationParams(reasoning_effort="medium") + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["reasoning_effort"] == "medium" + + +def test_thinking_budget_tokens_sets_thinking_dict(): + p = GenerationParams(reasoning_effort="medium", thinking_budget_tokens=2048) + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["thinking"] == {"type": "enabled", "budget_tokens": 2048} + + +def test_effort_high_on_claude(): + p = GenerationParams(reasoning_effort="high") + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["reasoning_effort"] == "high" + assert "thinking" not in kw or kw.get("thinking") != {"type": "disabled"} + + +def test_effort_max_on_claude(): + p = GenerationParams(reasoning_effort="max") + kw = to_litellm_kwargs(_CLAUDE_MODEL, p) + assert kw["reasoning_effort"] == "max" + + +# ── thinking: unsupported model (WatsonX) ──────────────────────────────────── + + +def test_no_thinking_kwargs_on_watsonx_when_none(): + p = GenerationParams(reasoning_effort="none") + kw = to_litellm_kwargs(_WATSONX_MODEL, p) + assert "thinking" not in kw + assert "reasoning_effort" not in kw + + +def test_reasoning_stripped_with_warning_on_watsonx(caplog): + p = GenerationParams(reasoning_effort="medium") + with caplog.at_level(logging.WARNING, logger="llm.litellm"): + kw = to_litellm_kwargs(_WATSONX_MODEL, p) + assert "reasoning_effort" not in kw + assert "thinking" not in kw + assert "stripping" in caplog.text.lower() + + +# ── thinking: unknown model ─────────────────────────────────────────────────── + + +def test_reasoning_stripped_with_warning_on_unknown_model(caplog): + p = GenerationParams(reasoning_effort="low") + with caplog.at_level(logging.WARNING, logger="llm.litellm"): + kw = to_litellm_kwargs(_UNKNOWN_MODEL, p) + assert "reasoning_effort" not in kw + assert "stripping" in caplog.text.lower() + + +# ── watsonx: max_tokens still included ─────────────────────────────────────── + + +def test_max_tokens_on_watsonx(): + p = GenerationParams(max_tokens=4096) + kw = to_litellm_kwargs(_WATSONX_MODEL, p) + assert kw["max_tokens"] == 4096 + + +def test_temperature_on_watsonx(): + p = GenerationParams(temperature=0.3) + kw = to_litellm_kwargs(_WATSONX_MODEL, p) + assert kw["temperature"] == pytest.approx(0.3) + + +# ── proxy GPT ───────────────────────────────────────────────────────────────── + + +def test_reasoning_effort_forwarded_on_gpt(): + p = GenerationParams(reasoning_effort="high") + kw = to_litellm_kwargs(_GPT_MODEL, p) + assert kw["reasoning_effort"] == "high" diff --git a/src/servers/fmsr/main.py b/src/servers/fmsr/main.py index 1638b7629..dea70d198 100644 --- a/src/servers/fmsr/main.py +++ b/src/servers/fmsr/main.py @@ -104,7 +104,9 @@ def _build_llm(): missing = [v for v in ("LITELLM_API_KEY", "LITELLM_BASE_URL") if not os.environ.get(v)] if missing: raise RuntimeError(f"Missing env vars for LiteLLM: {missing}") - return LiteLLMBackend(model_id) + from llm import from_env as _gen_from_env + + return LiteLLMBackend(model_id, params=_gen_from_env()) try: