diff --git a/python/examples/prompt_construction/litellm/README.md b/python/examples/prompt_construction/litellm/README.md new file mode 100644 index 0000000..6d4db07 --- /dev/null +++ b/python/examples/prompt_construction/litellm/README.md @@ -0,0 +1,210 @@ +# LiteLLM examples + +This directory contains three small Context Compiler + LiteLLM integration examples: + +- `basic.py`: compiler-only flow (no directive drafter) +- `response_format.py`: host-side LiteLLM `response_format` selection from saved compiler state +- `with_directive_drafter.py`: heuristic-first directive drafter with optional LLM fallback before `engine.step(...)` + +## Requirements + +```shell +pip install "context-compiler[integrations]" +export OPENAI_API_KEY=... +``` + +Checkpoint continuation in these examples requires `context-compiler>=0.7.0`. + +For `with_directive_drafter.py`: + +```shell +pip install context-compiler-directive-drafter +``` + +## Quickstart (copy/paste) + +```shell +pip install "context-compiler[integrations]" +export OPENAI_API_KEY=... +export MODEL=openai/gpt-4o-mini +python - <<'PY' +from context_compiler import create_engine +from examples.integrations.litellm.basic import handle_turn +engine = create_engine() +print(handle_turn("set premise concise replies", engine)) +PY +``` + +For directive-drafter behavior: + +```shell +pip install context-compiler-directive-drafter +export OPENAI_API_KEY=... +export MODEL=openai/gpt-4o-mini +python - <<'PY' +from context_compiler import create_engine +from examples.integrations.litellm.with_directive_drafter import handle_turn +engine = create_engine() +print(handle_turn("set premise to concise replies", engine)) +PY +``` + +This near-miss input should return `clarify` instead of being rewritten. + +For host-side response shape selection: + +```shell +pip install "context-compiler[integrations]" +export OPENAI_API_KEY=... +export MODEL=openai/gpt-4o-mini +python - <<'PY' +from context_compiler import create_engine +from examples.integrations.litellm.response_format import plan_turn +engine = create_engine() +engine.step("use compact_summary") +print(plan_turn("Summarize the release notes.", engine)) +PY +``` + +## Environment configuration + +Required (normal `openai` mode): + +```shell +export OPENAI_API_KEY=... +``` + +Optional: + +```shell +export PROVIDER=openai +export MODEL=openai/gpt-4o-mini +export PREPROCESSOR_MODEL=openai/gpt-4o-mini +export OPENAI_BASE_URL=... +export PREPROCESSOR_PROMPT_PROFILE=default +``` + +Provider mode contract (`PROVIDER`) is strict: + +- `openai` +- `ollama` +- `openai_compatible` + +Unknown values hard fail with a validation error. + +Resolution precedence: + +1. `OPENAI_BASE_URL` override +2. `PROVIDER` +3. default (`openai`) + +Operational behavior by mode: + +- `openai` + - default `base_url`: `https://api.openai.com/v1` + - requires `OPENAI_API_KEY` +- `ollama` + - default `base_url`: `http://localhost:11434` + - API key optional +- `openai_compatible` + - requires `OPENAI_BASE_URL` when explicitly selected with `PROVIDER` + - API key requirement depends on endpoint + +Startup emits one concise config line showing resolved `mode`, `base_url`, `model`, +and resolution `source` (`default`, `PROVIDER`, or `OPENAI_BASE_URL override`). + +`MODEL` and `PREPROCESSOR_MODEL` use LiteLLM format: `/`. +`PREPROCESSOR_MODEL` is optional and defaults to `MODEL`. + +For heuristic-first usage, keep `PREPROCESSOR_PROMPT_PROFILE=default`. +Use `llama` only for LLM-only preprocessing with Llama-family models. + +## Usage pattern + +You can import these files as integration references in host applications. + +- Import `handle_turn(...)` from either `basic.py` or `with_directive_drafter.py`. +- Create and retain an engine instance in host/session state. +- Pass each user input through `handle_turn(user_input, engine)`. +- Optional checkpointing: pass `session_key=...`. + The example restores checkpoint data before the first `engine.step(...)` and + saves checkpoint data after `update`/`clarify`. +- In this example, checkpoint/session storage is in-memory only. + State lasts only for the current process. To survive restarts, store + checkpoints in external storage (DB/Redis/etc.). +- Display the returned assistant text. + +Note: In these LiteLLM examples, `update` is rendered locally and does not call +the downstream LLM. This makes state changes explicit. Production apps may +choose different rendering behavior. + +## Response format example boundary + +`response_format.py` shows a different integration boundary from prompt reinjection: + +- Context Compiler owns authoritative state. +- The host reads saved policy state and selects a LiteLLM `response_format` or omits it. +- LiteLLM owns model invocation and provider behavior. +- Context Compiler does not call LiteLLM on its own. +- Context Compiler does not validate model output. +- Context Compiler does not generate schemas dynamically. +- This is application-layer use of authoritative state, not compiler semantics. + +## Troubleshooting + +- `litellm is required`: install `context-compiler[integrations]` (and `context-compiler-directive-drafter` for directive-drafter flows). +- `OPENAI_API_KEY is required in openai mode`: export a key or use `ollama` / explicit endpoint override. +- `Invalid PROVIDER value ...`: set `PROVIDER` to one of `openai`, `ollama`, `openai_compatible`. +- `OPENAI_BASE_URL is required when PROVIDER=openai_compatible`: set an explicit endpoint URL. +- model/provider errors (`Model not found`, provider auth errors): confirm `MODEL` uses LiteLLM format and provider credentials are valid. + +## Basic vs directive-drafter behavior + +- Basic: passes raw user input to `engine.step(...)`. +- With directive drafter: runs heuristic directive drafter first. + - If heuristic returns a directive, that directive is passed to `engine.step(...)`. + - If heuristic does not produce a directive (`no_directive` or `unknown`), LLM fallback drafting runs. + - If fallback yields nothing usable or errors, behavior safely remains equivalent to basic. + - If `engine.has_pending_clarification()` is true, bypass directive drafting and pass raw input directly to `engine.step(...)`. + - Behavior is reject-first and does not broaden the directive grammar. + +Decision flow in both examples: +- `passthrough`: call the model with normal input. +- `clarify`: show `prompt_to_user`; do not treat state as changed. +- `update`: state changed; use updated state for the next model call. + +Decision flow in `response_format.py`: +- `passthrough`: let the host decide whether to send `response_format`. +- `clarify`: show `prompt_to_user`; do not call LiteLLM. +- `update`: state changed; the next host request may use a different `response_format`. + +## Example checks + +- Near-miss passthrough (`with_directive_drafter.py`): + - `set premise to concise replies` is not rewritten by the directive drafter and is passed through unchanged. + - Engine returns clarify (`Did you mean 'set premise concise replies'?`). +- Lifecycle enforcement (both): + - `change premise to formal tone` with no premise -> clarify (`set premise ...` first). +- Conflict behavior (both): + - `use docker` then `prohibit docker` -> conflict clarify. +- Replacement precondition (both): + - `use podman instead of docker` without prior `use docker` -> replacement clarify. +- Directive-adjacent abstain (`with_directive_drafter.py`): + - `change premise concise replies` is classified as `unknown`, not rewritten, and handled by engine clarify. +- Host-side request shaping (`response_format.py`): + - `use compact_summary` -> host selects compact-summary `response_format`. + - `use action_plan` -> host selects action-plan `response_format`. + - `prohibit compact_summary` -> host omits that `response_format`. + +## Optional smoke run for `response_format.py` + +```shell +export RUN_LITELLM_SMOKE=1 +export PROVIDER=ollama +export MODEL=ollama/qwen2.5:1.5b-instruct +uv run python examples/integrations/litellm/response_format.py +``` + +For local Ollama smoke runs in this repo, `PROVIDER=ollama` is required. A +`MODEL=ollama/...` value by itself still follows the default OpenAI provider +path. diff --git a/python/examples/prompt_construction/litellm/basic.py b/python/examples/prompt_construction/litellm/basic.py new file mode 100644 index 0000000..49da0f9 --- /dev/null +++ b/python/examples/prompt_construction/litellm/basic.py @@ -0,0 +1,349 @@ +"""Minimal LiteLLM integration with Context Compiler. + +Flow: +1. Call engine.step(user_input) +2. clarify -> return prompt_to_user (no model call) +3. update -> return deterministic acknowledgment text (no model call) +4. passthrough -> call LiteLLM with compiled state + user input + +Intended host usage: +- collect user input +- call handle_turn(user_input, engine) +- display returned assistant text +""" + +import logging +import re +from collections.abc import Callable, Mapping, Sequence +from importlib import import_module +from typing import TypedDict, cast + +from context_compiler import ( + DECISION_CLARIFY, + DECISION_PASSTHROUGH, + DECISION_UPDATE, + POLICY_PROHIBIT, + POLICY_USE, + State, + get_clarify_prompt, + get_decision_state, + get_policy_items, + get_premise_value, + is_clarify, + is_passthrough, + is_update, +) +from context_compiler.engine import Engine +from context_compiler.observability import build_trace + +try: + from host_support import is_confirmation_text +except ImportError: + import host_support.confirmation as _confirmation + + is_confirmation_text = _confirmation.is_confirmation_text + +try: + from host_support.confirmation import summarize_confirmation_update_from_checkpoint +except ImportError: + from host_support.confirmation import ( + summarize_confirmation_update as _summarize_confirmation_update_from_pending, + ) + + def summarize_confirmation_update_from_checkpoint(user_input: str, checkpoint: object) -> str: + pending = checkpoint.get("pending") if isinstance(checkpoint, dict) else None + return _summarize_confirmation_update_from_pending(user_input, pending) + + +try: + from host_support import print_startup_config, resolve_provider_config +except ImportError: + from host_support.provider_mode import print_startup_config, resolve_provider_config + +logger = logging.getLogger(__name__) +# Example-only in-memory checkpoint store. +# This keeps continuation state only for the current process lifetime. +# Real deployments should persist checkpoints externally (DB/Redis/etc.), +# or restart continuity for pending flows will be lost. +_CHECKPOINTS_BY_SESSION_KEY: dict[str, str] = {} +_RESTORED_ENGINE_BY_SESSION_KEY: dict[str, int] = {} +_NEGATIVE_CONFIRMATION_TOKENS = {"no", "nope", "no thanks"} +_TRAILING_CONFIRM_PUNCT_RE = re.compile(r"[.,!?]+$") +SHOW_CONTEXT_COMPILER_TRACE = False + + +class _LiteLLMCallKwargs(TypedDict, total=False): + model: str + messages: list[dict[str, str]] + api_key: str + temperature: float + api_base: str + + +def _extract_response_content(response: object) -> str | None: + if isinstance(response, Mapping): + choices = response.get("choices") + if isinstance(choices, Sequence) and choices: + first = choices[0] + if isinstance(first, Mapping): + message = first.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + return content + + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, Sequence) and choices_attr: + first = choices_attr[0] + message_attr = getattr(first, "message", None) + content_attr = getattr(message_attr, "content", None) + if isinstance(content_attr, str): + return content_attr + + return None + + +def _render_compiled_state_contract(compiled_state: State) -> str: + premise = get_premise_value(compiled_state) + use_items = sorted(get_policy_items(compiled_state, POLICY_USE)) + prohibit_items = sorted(get_policy_items(compiled_state, POLICY_PROHIBIT)) + + lines: list[str] = ["The following constraints are authoritative."] + if premise: + lines.append(f"Current premise: {premise}.") + if use_items: + lines.append("Items marked use: " + ", ".join(use_items) + ".") + if prohibit_items: + lines.append("Items marked prohibit: " + ", ".join(prohibit_items) + ".") + lines.append("If user text conflicts with constraints, follow constraints exactly.") + + return "Host policy contract:\n" + "\n".join(f"- {line}" for line in lines) + + +def _build_messages(user_input: str, compiled_state: State) -> list[dict[str, str]]: + return [ + { + "role": "system", + "content": "You are a helpful assistant.\n" + + _render_compiled_state_contract(compiled_state), + }, + {"role": "user", "content": user_input}, + ] + + +def _call_litellm(messages: list[dict[str, str]]) -> str: + try: + litellm_module = import_module("litellm") + except ModuleNotFoundError as exc: + raise RuntimeError("litellm is required. Install with: pip install litellm") from exc + completion_fn = cast(Callable[..., object], litellm_module.completion) + + config = resolve_provider_config(default_model="openai/gpt-4o-mini") + print_startup_config(config, logger=logger) + + kwargs: _LiteLLMCallKwargs = { + "model": config.model, + "messages": messages, + "temperature": 0, + "api_base": config.base_url, + } + if config.api_key: + kwargs["api_key"] = config.api_key + + response = completion_fn(**kwargs) + content = _extract_response_content(response) + if content is None: + raise RuntimeError("LiteLLM response missing choices[0].message.content") + return content + + +def _restore_session_checkpoint_if_needed(engine: Engine, session_key: str | None) -> None: + if session_key is None: + return + engine_id = id(engine) + if _RESTORED_ENGINE_BY_SESSION_KEY.get(session_key) == engine_id: + return + + checkpoint = _CHECKPOINTS_BY_SESSION_KEY.get(session_key) + if checkpoint is not None: + engine.import_checkpoint_json(checkpoint) + _RESTORED_ENGINE_BY_SESSION_KEY[session_key] = engine_id + + +def _persist_session_checkpoint_if_needed( + engine: Engine, kind: str, session_key: str | None +) -> None: + if session_key is None: + return + if kind not in {DECISION_UPDATE, DECISION_CLARIFY}: + return + _CHECKPOINTS_BY_SESSION_KEY[session_key] = engine.export_checkpoint_json() + + +def _normalize_confirmation_for_summary(value: str) -> str: + normalized = value.strip().lower() + normalized = re.sub(r"\s+", " ", normalized) + normalized = _TRAILING_CONFIRM_PUNCT_RE.sub("", normalized).strip() + return re.sub(r"\s+", " ", normalized) + + +def _render_item_label(value: str) -> str: + return re.sub(r"\s+", " ", value).strip().lower() + + +def _near_miss_directive_clarify(value: str) -> str | None: + normalized = re.sub(r"\s+", " ", value.strip()) + lower = normalized.lower() + + if lower in {"reset premise", "reset premises", "clear premises"}: + return "Unknown directive.\nUse 'clear premise' or 'reset policies'." + if lower.startswith("set premise to "): + return "Invalid premise syntax.\nUse 'set premise '." + if lower.startswith("change premise ") and not lower.startswith("change premise to "): + return "Invalid premise syntax.\nUse 'change premise to '." + return None + + +def _summarize_confirmation_update(user_input: str, checkpoint: object) -> str: + summarize_fn: Callable[[str, object], str] = summarize_confirmation_update_from_checkpoint + return summarize_fn(user_input, checkpoint) + + +def _summarize_update_from_input(user_input: str) -> str: + normalized = re.sub(r"\s+", " ", user_input.strip()) + lower = normalized.lower() + + if lower == "clear state": + return "State cleared." + if lower == "clear premise": + return "Premise cleared." + if lower == "reset policies": + return "Policies reset." + + replacement_match = re.match( + r"^use\s+(.+?)\s+instead\s+of\s+(.+)$", normalized, flags=re.IGNORECASE + ) + if replacement_match is not None: + item = _render_item_label(replacement_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + use_match = re.match(r"^use\s+(.+)$", normalized, flags=re.IGNORECASE) + if use_match is not None: + item = _render_item_label(use_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + prohibit_match = re.match(r"^prohibit\s+(.+)$", normalized, flags=re.IGNORECASE) + if prohibit_match is not None: + item = _render_item_label(prohibit_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Prohibit {item}." + + remove_policy_match = re.match(r"^remove\s+policy\s+(.+)$", normalized, flags=re.IGNORECASE) + if remove_policy_match is not None: + item = _render_item_label(remove_policy_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Removed policy {item}." + + return "State updated." + + +def _append_trace( + response_text: str, + *, + original_input: str, + compiler_input: str, + decision: object, + state_before: object, + state_after: object, + llm_called: bool, +) -> str: + if not SHOW_CONTEXT_COMPILER_TRACE: + return response_text + trace_text = build_trace( + original_input=original_input, + compiler_input=compiler_input, + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=llm_called, + ) + return f"{response_text}\n\n{trace_text}" + + +def handle_turn(user_input: str, engine: Engine, *, session_key: str | None = None) -> str: + _restore_session_checkpoint_if_needed(engine, session_key) + state_before = engine.state + has_pending_before = engine.has_pending_clarification() + checkpoint_before = engine.export_checkpoint() if has_pending_before else None + logger.debug("litellm_basic: engine_input=%s", f"user_input len={len(user_input)}") + decision = engine.step(user_input) + if is_clarify(decision): + kind = DECISION_CLARIFY + elif is_update(decision): + kind = DECISION_UPDATE + else: + kind = DECISION_PASSTHROUGH + logger.debug("litellm_basic: decision=%s", kind) + near_miss_prompt = _near_miss_directive_clarify(user_input) + + if is_clarify(decision): + _persist_session_checkpoint_if_needed(engine, kind, session_key) + response_text = near_miss_prompt or get_clarify_prompt(decision) or "" + return _append_trace( + response_text, + original_input=user_input, + compiler_input=user_input, + decision=decision, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + if near_miss_prompt is not None and is_passthrough(decision): + return _append_trace( + near_miss_prompt, + original_input=user_input, + compiler_input=user_input, + decision={"kind": DECISION_CLARIFY, "prompt_to_user": near_miss_prompt}, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + _persist_session_checkpoint_if_needed(engine, kind, session_key) + if is_update(decision) and is_confirmation_text(user_input) and checkpoint_before is not None: + response_text = _summarize_confirmation_update(user_input, checkpoint_before) + return _append_trace( + response_text, + original_input=user_input, + compiler_input=user_input, + decision=decision, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + if is_update(decision): + response_text = _summarize_update_from_input(user_input) + return _append_trace( + response_text, + original_input=user_input, + compiler_input=user_input, + decision=decision, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + + decision_state = get_decision_state(decision) + compiled_state = decision_state if decision_state is not None else engine.state + messages = _build_messages(user_input, compiled_state) + response_text = _call_litellm(messages) + return _append_trace( + response_text, + original_input=user_input, + compiler_input=user_input, + decision=decision, + state_before=state_before, + state_after=compiled_state, + llm_called=True, + ) diff --git a/python/examples/prompt_construction/litellm/with_directive_drafter.py b/python/examples/prompt_construction/litellm/with_directive_drafter.py new file mode 100644 index 0000000..e35feb9 --- /dev/null +++ b/python/examples/prompt_construction/litellm/with_directive_drafter.py @@ -0,0 +1,477 @@ +"""LiteLLM integration with optional directive drafter before Context Compiler. + +Flow: +1. Extract user input +2. Run heuristic directive drafter +3. If no directive, run LLM fallback directive drafter using prompt files +4. Pass directive (or original input) to engine.step(...) +5. clarify -> return prompt_to_user (no model call) +6. update -> return deterministic acknowledgment text (no model call) +7. passthrough -> call LiteLLM with compiled state + user input + +Intended host usage: +- collect user input +- call handle_turn(user_input, engine) +- display returned assistant text +""" + +import logging +import os +import re +from collections.abc import Callable, Mapping, Sequence +from importlib import import_module +from importlib.resources import as_file, files +from importlib.resources.abc import Traversable +from typing import TypedDict, cast + +from context_compiler import ( + DECISION_CLARIFY, + DECISION_PASSTHROUGH, + DECISION_UPDATE, + POLICY_PROHIBIT, + POLICY_USE, + State, + get_clarify_prompt, + get_decision_state, + get_policy_items, + get_premise_value, + is_clarify, + is_passthrough, + is_update, +) +from context_compiler.engine import Engine +from context_compiler.observability import build_trace +from context_compiler_directive_drafter import ( + PREPROCESS_OUTCOME_DIRECTIVE, + parse_preprocessor_output, + preprocess_heuristic, + render_prompt, +) + +try: + from host_support import is_confirmation_text +except ImportError: + import host_support.confirmation as _confirmation + + is_confirmation_text = _confirmation.is_confirmation_text + +try: + from host_support.confirmation import summarize_confirmation_update_from_checkpoint +except ImportError: + from host_support.confirmation import ( + summarize_confirmation_update as _summarize_confirmation_update_from_pending, + ) + + def summarize_confirmation_update_from_checkpoint(user_input: str, checkpoint: object) -> str: + pending = checkpoint.get("pending") if isinstance(checkpoint, dict) else None + return _summarize_confirmation_update_from_pending(user_input, pending) + + +try: + from host_support import print_startup_config, resolve_provider_config +except ImportError: + from host_support.provider_mode import print_startup_config, resolve_provider_config + +logger = logging.getLogger(__name__) + +_PROMPTS_DIR = files("context_compiler_directive_drafter").joinpath("prompts") +# Example-only in-memory checkpoint store. +# This keeps continuation state only for the current process lifetime. +# Real deployments should persist checkpoints externally (DB/Redis/etc.), +# or restart continuity for pending flows will be lost. +_CHECKPOINTS_BY_SESSION_KEY: dict[str, str] = {} +_RESTORED_ENGINE_BY_SESSION_KEY: dict[str, int] = {} +_NEGATIVE_CONFIRMATION_TOKENS = {"no", "nope", "no thanks"} +_TRAILING_CONFIRM_PUNCT_RE = re.compile(r"[.,!?]+$") +SHOW_CONTEXT_COMPILER_TRACE = False + + +def _is_directive_shaped_input(message: str) -> bool: + normalized = re.sub(r"\s+", " ", message.strip()).lower() + return ( + normalized.startswith("use") + or normalized.startswith("prohibit") + or normalized.startswith("remove policy") + or normalized.startswith("set premise") + or normalized.startswith("change premise") + or normalized.startswith("clear") + or normalized.startswith("reset") + ) + + +class _LiteLLMCallKwargs(TypedDict, total=False): + model: str + messages: list[dict[str, str]] + api_key: str + temperature: float + api_base: str + + +def _extract_response_content(response: object) -> str | None: + if isinstance(response, Mapping): + choices = response.get("choices") + if isinstance(choices, Sequence) and choices: + first = choices[0] + if isinstance(first, Mapping): + message = first.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + return content + + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, Sequence) and choices_attr: + first = choices_attr[0] + message_attr = getattr(first, "message", None) + content_attr = getattr(message_attr, "content", None) + if isinstance(content_attr, str): + return content_attr + + return None + + +def _get_litellm_completion() -> Callable[..., object]: + litellm_module = import_module("litellm") + return cast(Callable[..., object], litellm_module.completion) + + +def _render_compiled_state_contract(compiled_state: State) -> str: + premise = get_premise_value(compiled_state) + use_items = sorted(get_policy_items(compiled_state, POLICY_USE)) + prohibit_items = sorted(get_policy_items(compiled_state, POLICY_PROHIBIT)) + + lines: list[str] = ["The following constraints are authoritative."] + if premise: + lines.append(f"Current premise: {premise}.") + if use_items: + lines.append("Items marked use: " + ", ".join(use_items) + ".") + if prohibit_items: + lines.append("Items marked prohibit: " + ", ".join(prohibit_items) + ".") + lines.append("If user text conflicts with constraints, follow constraints exactly.") + + return "Host policy contract:\n" + "\n".join(f"- {line}" for line in lines) + + +def _build_messages(user_input: str, compiled_state: State) -> list[dict[str, str]]: + return [ + { + "role": "system", + "content": "You are a helpful assistant.\n" + + _render_compiled_state_contract(compiled_state), + }, + {"role": "user", "content": user_input}, + ] + + +def _call_litellm(messages: list[dict[str, str]]) -> str: + try: + completion = _get_litellm_completion() + except ModuleNotFoundError as exc: + raise RuntimeError("litellm is required. Install with: pip install litellm") from exc + + config = resolve_provider_config(default_model="openai/gpt-4o-mini") + print_startup_config(config, logger=logger) + + kwargs: _LiteLLMCallKwargs = { + "model": config.model, + "messages": messages, + "temperature": 0, + "api_base": config.base_url, + } + if config.api_key: + kwargs["api_key"] = config.api_key + + response = completion(**kwargs) + content = _extract_response_content(response) + if content is None: + raise RuntimeError("LiteLLM response missing choices[0].message.content") + return content + + +def _prompt_file_path() -> Traversable: + profile = os.getenv("PREPROCESSOR_PROMPT_PROFILE", "default").strip().lower() + if profile == "llama": + return _PROMPTS_DIR.joinpath("llama.txt") + return _PROMPTS_DIR.joinpath("default.txt") + + +def _llm_fallback_preprocess(message: str, state: State) -> str | None: + with as_file(_prompt_file_path()) as prompt_path: + prompt = render_prompt(prompt_path, state) + if prompt is None: + return None + + try: + completion = _get_litellm_completion() + except ModuleNotFoundError: + return None + + try: + config = resolve_provider_config(default_model="openai/gpt-4o-mini") + except RuntimeError: + return None + if config.mode == "openai" and not config.api_key: + return None + preprocessor_model = os.getenv("PREPROCESSOR_MODEL", "").strip() + if not preprocessor_model: + preprocessor_model = os.getenv("MODEL", "openai/gpt-4o-mini") + + kwargs: _LiteLLMCallKwargs = { + "model": preprocessor_model, + "messages": [ + {"role": "system", "content": prompt}, + {"role": "user", "content": message}, + ], + "temperature": 0, + "api_base": config.base_url, + } + if config.api_key: + kwargs["api_key"] = config.api_key + + try: + response = completion(**kwargs) + raw_output = _extract_response_content(response) + except Exception: + return None + + parsed = parse_preprocessor_output(raw_output, source_input=message) + if parsed is None: + return None + return parsed + + +def _preprocess_user_input(message: str, state: State) -> str | None: + # Heuristic first (fast + high precision), then optional LLM fallback. + try: + heuristic_result = preprocess_heuristic(message) + logger.debug("preprocessor: heuristic_outcome=%s", heuristic_result["outcome"]) + if ( + heuristic_result["outcome"] == PREPROCESS_OUTCOME_DIRECTIVE + and heuristic_result["directive"] + ): + parsed = parse_preprocessor_output(heuristic_result["directive"]) + logger.debug("preprocessor: heuristic_directive=%r", heuristic_result["directive"]) + if parsed is not None: + return parsed + except Exception: + logger.debug("preprocessor: heuristic_exception", exc_info=True) + + if _is_directive_shaped_input(message): + return None + + try: + fallback_directive = _llm_fallback_preprocess(message, state) + logger.debug("preprocessor: fallback_directive=%r", fallback_directive) + return fallback_directive + except Exception: + # Safe no-op fallback: if preprocessor path fails, preserve basic behavior. + return None + + +def _restore_session_checkpoint_if_needed(engine: Engine, session_key: str | None) -> None: + if session_key is None: + return + engine_id = id(engine) + if _RESTORED_ENGINE_BY_SESSION_KEY.get(session_key) == engine_id: + return + + checkpoint = _CHECKPOINTS_BY_SESSION_KEY.get(session_key) + if checkpoint is not None: + engine.import_checkpoint_json(checkpoint) + _RESTORED_ENGINE_BY_SESSION_KEY[session_key] = engine_id + + +def _persist_session_checkpoint_if_needed( + engine: Engine, kind: str, session_key: str | None +) -> None: + if session_key is None: + return + if kind not in {DECISION_UPDATE, DECISION_CLARIFY}: + return + _CHECKPOINTS_BY_SESSION_KEY[session_key] = engine.export_checkpoint_json() + + +def _normalize_confirmation_for_summary(value: str) -> str: + normalized = value.strip().lower() + normalized = re.sub(r"\s+", " ", normalized) + normalized = _TRAILING_CONFIRM_PUNCT_RE.sub("", normalized).strip() + return re.sub(r"\s+", " ", normalized) + + +def _render_item_label(value: str) -> str: + return re.sub(r"\s+", " ", value).strip().lower() + + +def _near_miss_directive_clarify(value: str) -> str | None: + normalized = re.sub(r"\s+", " ", value.strip()) + lower = normalized.lower() + + if lower in {"reset premise", "reset premises", "clear premises"}: + return "Unknown directive.\nUse 'clear premise' or 'reset policies'." + if lower.startswith("set premise to "): + return "Invalid premise syntax.\nUse 'set premise '." + if lower.startswith("change premise ") and not lower.startswith("change premise to "): + return "Invalid premise syntax.\nUse 'change premise to '." + return None + + +def _summarize_confirmation_update(user_input: str, checkpoint: object) -> str: + summarize_fn: Callable[[str, object], str] = summarize_confirmation_update_from_checkpoint + return summarize_fn(user_input, checkpoint) + + +def _summarize_update_from_input(user_input: str) -> str: + normalized = re.sub(r"\s+", " ", user_input.strip()) + lower = normalized.lower() + + if lower == "clear state": + return "State cleared." + if lower == "clear premise": + return "Premise cleared." + if lower == "reset policies": + return "Policies reset." + + replacement_match = re.match( + r"^use\s+(.+?)\s+instead\s+of\s+(.+)$", normalized, flags=re.IGNORECASE + ) + if replacement_match is not None: + item = _render_item_label(replacement_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + use_match = re.match(r"^use\s+(.+)$", normalized, flags=re.IGNORECASE) + if use_match is not None: + item = _render_item_label(use_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + prohibit_match = re.match(r"^prohibit\s+(.+)$", normalized, flags=re.IGNORECASE) + if prohibit_match is not None: + item = _render_item_label(prohibit_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Prohibit {item}." + + remove_policy_match = re.match(r"^remove\s+policy\s+(.+)$", normalized, flags=re.IGNORECASE) + if remove_policy_match is not None: + item = _render_item_label(remove_policy_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Removed policy {item}." + + return "State updated." + + +def _append_trace( + response_text: str, + *, + original_input: str, + compiler_input: str, + preprocessor_output: str | None, + decision: object, + state_before: object, + state_after: object, + llm_called: bool, +) -> str: + if not SHOW_CONTEXT_COMPILER_TRACE: + return response_text + trace_text = build_trace( + original_input=original_input, + compiler_input=compiler_input, + preprocessor_output=preprocessor_output, + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=llm_called, + ) + return f"{response_text}\n\n{trace_text}" + + +def handle_turn(user_input: str, engine: Engine, *, session_key: str | None = None) -> str: + _restore_session_checkpoint_if_needed(engine, session_key) + state_before = engine.state + has_pending_before = engine.has_pending_clarification() + checkpoint_before = engine.export_checkpoint() if has_pending_before else None + preprocessd: str | None = None + if engine.has_pending_clarification(): + compile_input = user_input + else: + preprocessd = _preprocess_user_input(user_input, engine.state) + compile_input = preprocessd if preprocessd else user_input + logger.debug( + "preprocessor: engine_input=%s", + "directive" if preprocessd else f"user_input len={len(user_input)}", + ) + + decision = engine.step(compile_input) + if is_clarify(decision): + kind = DECISION_CLARIFY + elif is_update(decision): + kind = DECISION_UPDATE + else: + kind = DECISION_PASSTHROUGH + logger.debug("preprocessor: decision=%s", kind) + near_miss_prompt = _near_miss_directive_clarify(user_input) + + if is_clarify(decision): + _persist_session_checkpoint_if_needed(engine, kind, session_key) + response_text = near_miss_prompt or get_clarify_prompt(decision) or "" + return _append_trace( + response_text, + original_input=user_input, + compiler_input=compile_input, + preprocessor_output=preprocessd, + decision=decision, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + if near_miss_prompt is not None and is_passthrough(decision): + return _append_trace( + near_miss_prompt, + original_input=user_input, + compiler_input=compile_input, + preprocessor_output=preprocessd, + decision={"kind": DECISION_CLARIFY, "prompt_to_user": near_miss_prompt}, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + _persist_session_checkpoint_if_needed(engine, kind, session_key) + if is_update(decision) and is_confirmation_text(user_input) and checkpoint_before is not None: + response_text = _summarize_confirmation_update(user_input, checkpoint_before) + return _append_trace( + response_text, + original_input=user_input, + compiler_input=compile_input, + preprocessor_output=preprocessd, + decision=decision, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + if is_update(decision): + response_text = _summarize_update_from_input(compile_input) + return _append_trace( + response_text, + original_input=user_input, + compiler_input=compile_input, + preprocessor_output=preprocessd, + decision=decision, + state_before=state_before, + state_after=engine.state, + llm_called=False, + ) + + decision_state = get_decision_state(decision) + compiled_state = decision_state if decision_state is not None else engine.state + messages = _build_messages(user_input, compiled_state) + response_text = _call_litellm(messages) + return _append_trace( + response_text, + original_input=user_input, + compiler_input=compile_input, + preprocessor_output=preprocessd, + decision=decision, + state_before=state_before, + state_after=compiled_state, + llm_called=True, + ) diff --git a/python/examples/schema_selection/litellm_response_format/response_format.py b/python/examples/schema_selection/litellm_response_format/response_format.py new file mode 100644 index 0000000..f70ce59 --- /dev/null +++ b/python/examples/schema_selection/litellm_response_format/response_format.py @@ -0,0 +1,214 @@ +"""Minimal LiteLLM response_format selection from authoritative state. + +Flow: +Context Compiler state -> host response_format decision -> LiteLLM model call. + +This example keeps model execution optional so tests can validate behavior +without a live provider. +""" + +import os +from collections.abc import Callable, Mapping +from importlib import import_module +from typing import Any, TypedDict, cast + +from context_compiler import ( + POLICY_PROHIBIT, + POLICY_USE, + State, + create_engine, + get_clarify_prompt, + get_decision_state, + get_policy_items, + is_clarify, +) +from context_compiler.engine import Engine + +try: + from host_support import print_startup_config, resolve_provider_config +except ImportError: + from host_support.provider_mode import print_startup_config, resolve_provider_config + +COMPACT_SUMMARY_RESPONSE_FORMAT: dict[str, Any] = { + "type": "json_schema", + "json_schema": { + "name": "compact_summary", + "schema": { + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "A compact summary of the answer.", + } + }, + "required": ["summary"], + "additionalProperties": False, + }, + }, +} + +ACTION_PLAN_RESPONSE_FORMAT: dict[str, Any] = { + "type": "json_schema", + "json_schema": { + "name": "action_plan", + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": {"type": "string"}, + "description": "Ordered next steps for the user.", + } + }, + "required": ["steps"], + "additionalProperties": False, + }, + }, +} + +_RESPONSE_FORMAT_BY_ITEM: dict[str, dict[str, Any]] = { + "compact_summary": COMPACT_SUMMARY_RESPONSE_FORMAT, + "action_plan": ACTION_PLAN_RESPONSE_FORMAT, +} + + +class TurnPlan(TypedDict): + decision_kind: str + clarify_prompt: str | None + selected_response_format_item: str | None + response_format: dict[str, Any] | None + + +class _LiteLLMCallKwargs(TypedDict, total=False): + model: str + messages: list[dict[str, str]] + temperature: float + api_base: str + api_key: str + response_format: dict[str, Any] + + +def select_litellm_response_format(state: State) -> tuple[str | None, dict[str, Any] | None]: + """Return (policy_item, response_format) or (None, None) when no safe match exists.""" + + use_items = set(get_policy_items(state, POLICY_USE)) + prohibit_items = set(get_policy_items(state, POLICY_PROHIBIT)) + + for item, response_format in _RESPONSE_FORMAT_BY_ITEM.items(): + if item in use_items and item not in prohibit_items: + return item, response_format + + return None, None + + +def plan_turn(user_input: str, engine: Engine) -> TurnPlan: + """Run compiler step and decide whether to request LiteLLM structured output.""" + + decision = engine.step(user_input) + if is_clarify(decision): + return { + "decision_kind": "clarify", + "clarify_prompt": get_clarify_prompt(decision), + "selected_response_format_item": None, + "response_format": None, + } + + decision_state = get_decision_state(decision) + compiled_state = decision_state if decision_state is not None else engine.state + selected_item, response_format = select_litellm_response_format(compiled_state) + + return { + "decision_kind": str(decision["kind"]), + "clarify_prompt": None, + "selected_response_format_item": selected_item, + "response_format": response_format, + } + + +def _get_litellm_completion() -> Callable[..., object]: + litellm_module = import_module("litellm") + return cast(Callable[..., object], litellm_module.completion) + + +def _extract_response_content(response: object) -> str | None: + if isinstance(response, Mapping): + choices = response.get("choices") + if isinstance(choices, list) and choices: + first = choices[0] + if isinstance(first, Mapping): + message = first.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + return content + + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, list) and choices_attr: + first = choices_attr[0] + message_attr = getattr(first, "message", None) + content_attr = getattr(message_attr, "content", None) + if isinstance(content_attr, str): + return content_attr + + return None + + +def optional_litellm_call( + *, + user_input: str, + response_format: Mapping[str, Any] | None, +) -> str: + """Optional smoke call to LiteLLM. + + If `response_format` is provided, it is passed through unchanged. + """ + + try: + completion = _get_litellm_completion() + except ModuleNotFoundError as exc: + raise RuntimeError("litellm is required. Install with: pip install litellm") from exc + + config = resolve_provider_config(default_model="openai/gpt-4o-mini") + print_startup_config(config) + + kwargs: _LiteLLMCallKwargs = { + "model": config.model, + "messages": [{"role": "user", "content": user_input}], + "temperature": 0, + "api_base": config.base_url, + } + if config.api_key: + kwargs["api_key"] = config.api_key + if response_format is not None: + kwargs["response_format"] = dict(response_format) + + response = completion(**kwargs) + content = _extract_response_content(response) + if content is None: + raise RuntimeError("LiteLLM response missing choices[0].message.content") + return content + + +def main() -> None: + engine = create_engine() + + # Demonstration setup. + engine.step("use compact_summary") + engine.step("prohibit action_plan") + + plan = plan_turn("Summarize what changed in this project.", engine) + print("decision_kind:", plan["decision_kind"]) + print("selected_response_format_item:", plan["selected_response_format_item"]) + print("response_format_selected:", plan["response_format"] is not None) + + # Optional model execution path; disabled by default. + if os.getenv("RUN_LITELLM_SMOKE") == "1": + response = optional_litellm_call( + user_input="Summarize what changed in this project.", + response_format=plan["response_format"], + ) + print("litellm_response:", response) + + +if __name__ == "__main__": + main() diff --git a/python/examples/schema_selection/ollama_structured_output/README.md b/python/examples/schema_selection/ollama_structured_output/README.md new file mode 100644 index 0000000..d5d8129 --- /dev/null +++ b/python/examples/schema_selection/ollama_structured_output/README.md @@ -0,0 +1,58 @@ +# Ollama structured output (host-side selection) + +This example shows a visible host behavior change that is different from prompt reinjection. + +Flow: + +`Context Compiler state -> host schema decision -> Ollama format request -> model call` + +The host reads compiled policy state, picks a JSON Schema (or none), and sends that choice through Ollama's `format` field. + +## What this example guarantees + +- Context Compiler provides deterministic state transitions. +- The host integration decides whether to request a schema. +- Ollama structured output is a runtime request made by the host. +- If policy state is unknown or insufficient, the host requests no schema. + +## What `prohibit shell_command` means here + +- The host will not request the `shell_command` schema. +- The host may still request a different schema when policy supports it (for example, `python_script`). +- This does not block normal language discussion about shell commands. + +## Observable behavior + +Given policy state: + +```text +use python_script +prohibit shell_command +``` + +this host selects `python_script` schema and does not request `shell_command` schema. + +## Test boundary + +Tests verify schema selection behavior only: + +- compiler state -> selected schema (or no schema) +- contradiction handling stays in compiler `clarify` + +Tests do not assert exact model wording. + +## Run without Ollama + +```shell +uv run python examples/integrations/ollama_structured_output/example.py +``` + +## Optional Ollama smoke run + +```shell +export RUN_OLLAMA_SMOKE=1 +export OLLAMA_MODEL=llama3.1 +uv run python examples/integrations/ollama_structured_output/example.py +``` + +When smoke mode is enabled, the host sends the selected JSON Schema through Ollama `format`. diff --git a/python/examples/schema_selection/ollama_structured_output/example.py b/python/examples/schema_selection/ollama_structured_output/example.py new file mode 100644 index 0000000..ef2f3f6 --- /dev/null +++ b/python/examples/schema_selection/ollama_structured_output/example.py @@ -0,0 +1,169 @@ +"""Minimal host-side Ollama structured-output schema selection. + +Flow: +Context Compiler state -> host selection logic -> Ollama `format` JSON Schema. + +This example keeps model execution optional so tests can validate behavior without Ollama. +""" + +import json +import os +import urllib.error +import urllib.request +from collections.abc import Mapping +from typing import Any, TypedDict, cast + +from context_compiler import ( + POLICY_PROHIBIT, + POLICY_USE, + State, + create_engine, + get_clarify_prompt, + get_decision_state, + get_policy_items, + is_clarify, +) +from context_compiler.engine import Engine + +PYTHON_SCRIPT_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "python_script": { + "type": "string", + "description": "A complete Python script.", + } + }, + "required": ["python_script"], + "additionalProperties": False, +} + +SHELL_COMMAND_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "shell_command": { + "type": "string", + "description": "A single shell command.", + } + }, + "required": ["shell_command"], + "additionalProperties": False, +} + +# Small, explicit mapping from policy item -> Ollama `format` schema. +_SCHEMA_BY_ITEM: dict[str, dict[str, Any]] = { + "python_script": PYTHON_SCRIPT_SCHEMA, + "shell_command": SHELL_COMMAND_SCHEMA, +} + + +class TurnPlan(TypedDict): + decision_kind: str + clarify_prompt: str | None + selected_schema_item: str | None + format_schema: dict[str, Any] | None + + +def select_ollama_format_schema(state: State) -> tuple[str | None, dict[str, Any] | None]: + """Return (policy_item, schema) or (None, None) when no safe match exists. + + Unknown/insufficient policy state intentionally selects no schema. + """ + + use_items = set(get_policy_items(state, POLICY_USE)) + prohibit_items = set(get_policy_items(state, POLICY_PROHIBIT)) + + for item, schema in _SCHEMA_BY_ITEM.items(): + if item in use_items and item not in prohibit_items: + return item, schema + + return None, None + + +def plan_turn(user_input: str, engine: Engine) -> TurnPlan: + """Run compiler step and decide whether to request Ollama structured output.""" + + decision = engine.step(user_input) + if is_clarify(decision): + return { + "decision_kind": "clarify", + "clarify_prompt": get_clarify_prompt(decision), + "selected_schema_item": None, + "format_schema": None, + } + + decision_state = get_decision_state(decision) + compiled_state = decision_state if decision_state is not None else engine.state + selected_item, format_schema = select_ollama_format_schema(compiled_state) + + return { + "decision_kind": str(decision["kind"]), + "clarify_prompt": None, + "selected_schema_item": selected_item, + "format_schema": format_schema, + } + + +def optional_ollama_call( + *, + user_input: str, + model: str, + format_schema: Mapping[str, Any] | None, + host: str | None = None, +) -> dict[str, Any]: + """Optional smoke call to Ollama's /api/chat. + + If `format_schema` is provided, it is passed through `format` exactly. + """ + + base_url = host or os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434" + payload: dict[str, Any] = { + "model": model, + "messages": [{"role": "user", "content": user_input}], + "stream": False, + } + if format_schema is not None: + payload["format"] = dict(format_schema) + + request = urllib.request.Request( + url=f"{base_url.rstrip('/')}/api/chat", + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(request, timeout=30) as response: + raw = response.read().decode("utf-8") + except urllib.error.URLError as exc: + raise RuntimeError(f"Ollama call failed: {exc}") from exc + + decoded = cast(object, json.loads(raw)) + if not isinstance(decoded, dict): + raise RuntimeError("Ollama response must be a JSON object") + return cast(dict[str, Any], decoded) + + +def main() -> None: + engine = create_engine() + + # Demonstration setup. + engine.step("use python_script") + engine.step("prohibit shell_command") + + plan = plan_turn("Write a helper script to parse CSV files.", engine) + print("decision_kind:", plan["decision_kind"]) + print("selected_schema_item:", plan["selected_schema_item"]) + print("format_schema_selected:", plan["format_schema"] is not None) + + # Optional model execution path; disabled by default. + if os.getenv("RUN_OLLAMA_SMOKE") == "1": + response = optional_ollama_call( + user_input="Write a helper script to parse CSV files.", + model=os.getenv("OLLAMA_MODEL", "llama3.1"), + format_schema=plan["format_schema"], + ) + print("ollama_response_keys:", sorted(response.keys())) + + +if __name__ == "__main__": + main() diff --git a/python/reference_integrations/litellm_proxy/README.md b/python/reference_integrations/litellm_proxy/README.md index 8e5b0e6..3d0674b 100644 --- a/python/reference_integrations/litellm_proxy/README.md +++ b/python/reference_integrations/litellm_proxy/README.md @@ -1,10 +1,123 @@ -# LiteLLM Proxy reference integration - -- status: placeholder -- intended enforcement point: gateway middleware -- intended domain: customer support routing -- intended technology: LiteLLM Proxy -- no implementation yet -- examples must use explicit authoritative state -- examples must not derive Context Compiler state from model output -- examples must remain meaningful with an adversarial stub +# LiteLLM Proxy (pre-call hook) + +This example shows how to run Context Compiler inside a LiteLLM proxy pre-call hook. +The hook applies fixed state rules before any upstream model call. + +Available hook files: + +- Basic replay-only hook: `context_compiler_precall_hook.py` +- Directive-drafter-enabled hook: `context_compiler_precall_hook_with_directive_drafter.py` + +## Requirements + +```shell +pip install "context-compiler[litellm_proxy]" +export OPENAI_API_KEY=... +``` + +`litellm_proxy` is intentionally separate from the general `integrations` +extra because this path targets proxy/gateway runtime use. + +For `context_compiler_precall_hook_with_directive_drafter.py`: + +```shell +pip install context-compiler-directive-drafter +``` + +## Quickstart (copy/paste) + +From the repo root: + +```shell +pip install "context-compiler[litellm_proxy]" +export OPENAI_API_KEY=... +litellm --config examples/integrations/litellm_proxy/config.example.yaml +``` + +`config.example.yaml` includes both OpenAI and Ollama model definitions. +Use the Ollama model entry for local testing without API credentials. + +## Run proxy + +Typical startup command (environment-sensitive): + +```shell +litellm --config config.example.yaml +``` + +Hook behavior and proxy startup were re-validated end-to-end with +`litellm==1.88.2`. + +Validated behaviors: + +- passthrough: upstream model called normally +- update: compiler state injected before upstream model call +- clarify: request blocked before upstream model call and surfaced as HTTP 400 + +The proxy runs on `http://localhost:4000` by default. +By default, `config.example.yaml` points to the basic replay-only hook. +To use the directive-drafter variant, switch the callback path in the config. + +## Make a request + +```python +from openai import OpenAI + +client = OpenAI( + api_key="anything", + base_url="http://localhost:4000", +) + +response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "prohibit peanuts"}], +) +``` + +Or with curl: + +```shell +curl http://localhost:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer anything" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "prohibit peanuts"}] + }' +``` + +## Behavior + +- User messages are replayed through Context Compiler before the model call. +- If result is `clarify`, the proxy does not call the model and LiteLLM surfaces the clarification as an HTTP 400 response. +- If result is `passthrough`, the proxy forwards the request normally. +- If result is `update`, the proxy injects compiler state as a system message and then calls the model. + +Directive-drafter-enabled variant behavior: + +- Only the latest user transcript message is drafted for compiler replay input. +- Heuristic runs first; if no directive is found, LLM fallback is attempted. +- If `engine.has_pending_clarification()` is true, bypass directive drafting and pass raw input directly to `engine.step(...)`. +- Forwarded upstream request messages are not rewritten (except injected compiler system message). + +Optional env vars for directive-drafter fallback: + +```shell +export PREPROCESSOR_MODEL=openai/gpt-4o-mini +export PREPROCESSOR_PROMPT_PROFILE=default +``` + +`PREPROCESSOR_MODEL` is optional and defaults to `MODEL`. + +For heuristic-first usage, keep `PREPROCESSOR_PROMPT_PROFILE=default`. +Use `llama` only for LLM-only preprocessing with Llama-family models. + +## Note + +- The callback path in `config.example.yaml` must be importable by LiteLLM. + +## Troubleshooting + +- Callback import failures: verify the callback path configured in `config.example.yaml` is importable in the current LiteLLM environment. +- proxy starts but upstream calls fail: check `OPENAI_API_KEY` and upstream model/provider config in `config.example.yaml`. +- directive-drafter fallback issues: `PREPROCESSOR_MODEL` defaults to `MODEL`; set it explicitly only when using a separate fallback model. diff --git a/python/reference_integrations/litellm_proxy/config.example.yaml b/python/reference_integrations/litellm_proxy/config.example.yaml new file mode 100644 index 0000000..f0fff41 --- /dev/null +++ b/python/reference_integrations/litellm_proxy/config.example.yaml @@ -0,0 +1,19 @@ +model_list: + # `model_name` is the client-facing alias sent to the proxy. + # `litellm_params.model` is the upstream provider/model LiteLLM calls. + - model_name: gpt-4o-mini + litellm_params: + model: openai/gpt-4o-mini + api_key: os.environ/OPENAI_API_KEY + + - model_name: llama3.1 + litellm_params: + model: ollama/llama3.1:8b + api_base: http://localhost:11434 + +litellm_settings: + callbacks: + # Basic replay-only hook: + - context_compiler_precall_hook.proxy_handler_instance + # Preprocessor-enabled replay hook (use this instead of the basic hook): + # - context_compiler_precall_hook_with_directive_drafter.proxy_handler_instance diff --git a/python/reference_integrations/litellm_proxy/context_compiler_precall_hook.py b/python/reference_integrations/litellm_proxy/context_compiler_precall_hook.py new file mode 100644 index 0000000..1206c1d --- /dev/null +++ b/python/reference_integrations/litellm_proxy/context_compiler_precall_hook.py @@ -0,0 +1,135 @@ +"""Minimal LiteLLM Proxy pre-call hook example. + +Architecture: +- Replay user transcript through Context Compiler before any model call. +- If clarification is required, block upstream model call. +- Otherwise inject compiled state guidance into a system message. +""" + +import logging +from typing import Any + +try: + from litellm.integrations.custom_logger import CustomLogger +except ModuleNotFoundError: + # Keep this import path optional: CI/tests run without integration extras. + # A tiny fallback base class keeps module imports deterministic so coverage + # validates behavior instead of failing or silently skipping on missing litellm. + class CustomLogger: # type: ignore[no-redef] + pass + + +from context_compiler import ( + POLICY_PROHIBIT, + State, + Transcript, + compile_transcript, + get_policy_items, + get_premise_value, +) + +logger = logging.getLogger(__name__) + +_SUPPORTED_CALL_TYPES = { + "completion", + "acompletion", + "chat_completion", + "achat_completion", +} + + +def _render_compiled_state_contract(compiled_state: State) -> str: + prohibited = get_policy_items(compiled_state, POLICY_PROHIBIT) + premise = get_premise_value(compiled_state) + + lines: list[str] = ["The following constraints are authoritative."] + if prohibited: + items = ", ".join(prohibited) + lines.append(f"Never recommend or use prohibited items: {items}.") + if premise: + lines.append( + "When the answer depends on user preference/style, " + f"treat the current premise as: {premise}." + ) + lines.append("If the user message conflicts with these constraints, follow them exactly.") + + return "Host policy contract:\n" + "\n".join(f"- {line}" for line in lines) + + +def _extract_request_messages(data: dict[str, object]) -> list[dict[str, object]]: + raw_messages = data.get("messages") + if not isinstance(raw_messages, list): + return [] + return [msg for msg in raw_messages if isinstance(msg, dict)] + + +def _extract_text_content(content: object) -> str | None: + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts: list[str] = [] + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") != "text": + continue + text = item.get("text") + if isinstance(text, str): + text_parts.append(text) + if text_parts: + return " ".join(text_parts) + return None + + +def _extract_user_transcript(messages: list[dict[str, object]]) -> Transcript: + transcript: Transcript = [] + for message in messages: + role = message.get("role") + content = message.get("content") + text_content = _extract_text_content(content) + if role == "user" and text_content is not None: + transcript.append({"role": "user", "content": text_content}) + return transcript + + +class ContextCompilerPreCallHook(CustomLogger): + async def async_pre_call_hook( + self, + user_api_key_dict: Any, + cache: Any, + data: dict[str, object], + call_type: str, + ) -> dict[str, object] | str: + del user_api_key_dict, cache + logger.debug("litellm_proxy: call_type=%s", call_type) + if call_type not in _SUPPORTED_CALL_TYPES: + return data + + request_messages = _extract_request_messages(data) + logger.debug("litellm_proxy: message_count=%d", len(request_messages)) + user_transcript = _extract_user_transcript(request_messages) + logger.debug("litellm_proxy: transcript_len=%d", len(user_transcript)) + replay_result = compile_transcript(user_transcript) + logger.debug("litellm_proxy: replay_kind=%s", replay_result["kind"]) + + if replay_result["kind"] == "confirm": + # Returning a string from this pre-call hook blocks the upstream + # LiteLLM model call under LiteLLM callback semantics. + logger.debug("litellm_proxy: blocking_on_confirm=true") + return replay_result["prompt_to_user"] or "Confirmation required." + + compiled_state = replay_result["state"] + # For long-running conversations, you can optionally compact transcripts by removing user inputs that were compiled into state. See Demo 6. # noqa: E501 + system_message: dict[str, object] = { + "role": "system", + "content": "You are a helpful assistant.\n" + + _render_compiled_state_contract(compiled_state), + } + # Prepend one compiler contract system message, then forward the original + # request messages unchanged. Existing system messages are preserved. + logger.debug("litellm_proxy: inject_system_message=true") + data["messages"] = [system_message, *request_messages] + return data + + +proxy_handler_instance = ContextCompilerPreCallHook() diff --git a/python/reference_integrations/litellm_proxy/context_compiler_precall_hook_with_directive_drafter.py b/python/reference_integrations/litellm_proxy/context_compiler_precall_hook_with_directive_drafter.py new file mode 100644 index 0000000..b7c2cc3 --- /dev/null +++ b/python/reference_integrations/litellm_proxy/context_compiler_precall_hook_with_directive_drafter.py @@ -0,0 +1,280 @@ +"""LiteLLM Proxy pre-call hook with optional directive drafter on latest user message. + +Architecture: +- Replay user transcript through Context Compiler before any model call. +- Preprocess only the latest user message for compiler replay input. +- If clarification is required, block upstream model call. +- Otherwise inject compiled state guidance into a system message. +""" + +import logging +import os +from collections.abc import Callable, Mapping, Sequence +from importlib import import_module +from importlib.resources import as_file, files +from importlib.resources.abc import Traversable +from typing import Any, cast + +try: + from litellm.integrations.custom_logger import CustomLogger +except ModuleNotFoundError: + # Keep this import path optional: CI/tests run without integration extras. + # A tiny fallback base class keeps module imports deterministic so coverage + # validates behavior instead of failing or silently skipping on missing litellm. + class CustomLogger: # type: ignore[no-redef] + pass + + +from context_compiler import ( + POLICY_PROHIBIT, + State, + Transcript, + compile_transcript, + get_policy_items, + get_premise_value, +) +from context_compiler_directive_drafter import ( + PREPROCESS_OUTCOME_DIRECTIVE, + parse_preprocessor_output, + preprocess_heuristic, + render_prompt, +) + +logger = logging.getLogger(__name__) + +_SUPPORTED_CALL_TYPES = { + "completion", + "acompletion", + "chat_completion", + "achat_completion", +} + +_PROMPTS_DIR = files("context_compiler_directive_drafter").joinpath("prompts") + + +def _render_compiled_state_contract(compiled_state: State) -> str: + prohibited = get_policy_items(compiled_state, POLICY_PROHIBIT) + premise = get_premise_value(compiled_state) + + lines: list[str] = ["The following constraints are authoritative."] + if prohibited: + items = ", ".join(prohibited) + lines.append(f"Never recommend or use prohibited items: {items}.") + if premise: + lines.append( + "When the answer depends on user preference/style, " + f"treat the current premise as: {premise}." + ) + lines.append("If the user message conflicts with these constraints, follow them exactly.") + + return "Host policy contract:\n" + "\n".join(f"- {line}" for line in lines) + + +def _extract_request_messages(data: dict[str, object]) -> list[dict[str, object]]: + raw_messages = data.get("messages") + if not isinstance(raw_messages, list): + return [] + return [msg for msg in raw_messages if isinstance(msg, dict)] + + +def _extract_text_content(content: object) -> str | None: + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts: list[str] = [] + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") != "text": + continue + text = item.get("text") + if isinstance(text, str): + text_parts.append(text) + if text_parts: + return " ".join(text_parts) + return None + + +def _extract_user_transcript(messages: list[dict[str, object]]) -> Transcript: + transcript: Transcript = [] + for message in messages: + role = message.get("role") + content = message.get("content") + text_content = _extract_text_content(content) + if role == "user" and text_content is not None: + transcript.append({"role": "user", "content": text_content}) + return transcript + + +def _extract_response_content(response: object) -> str | None: + if isinstance(response, Mapping): + choices = response.get("choices") + if isinstance(choices, Sequence) and choices: + first = choices[0] + if isinstance(first, Mapping): + message = first.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + return content + + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, Sequence) and choices_attr: + first = choices_attr[0] + message_attr = getattr(first, "message", None) + content_attr = getattr(message_attr, "content", None) + if isinstance(content_attr, str): + return content_attr + + return None + + +def _prompt_file_path() -> Traversable: + profile = os.getenv("PREPROCESSOR_PROMPT_PROFILE", "default").strip().lower() + if profile == "llama": + return _PROMPTS_DIR.joinpath("llama.txt") + return _PROMPTS_DIR.joinpath("default.txt") + + +def _get_litellm_completion() -> Callable[..., object]: + litellm_module = import_module("litellm") + return cast(Callable[..., object], litellm_module.completion) + + +def _llm_fallback_preprocess(message: str, state: State) -> str | None: + with as_file(_prompt_file_path()) as prompt_path: + prompt = render_prompt(prompt_path, state) + if prompt is None: + return None + + preprocessor_model = os.getenv("PREPROCESSOR_MODEL", "").strip() + if not preprocessor_model: + preprocessor_model = os.getenv("MODEL", "").strip() + if not preprocessor_model: + return None + + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + return None + + try: + completion = _get_litellm_completion() + except ModuleNotFoundError: + return None + + kwargs: dict[str, object] = { + "model": preprocessor_model, + "messages": [ + {"role": "system", "content": prompt}, + {"role": "user", "content": message}, + ], + "api_key": api_key, + "temperature": 0, + } + api_base = os.getenv("OPENAI_BASE_URL") + if api_base: + kwargs["api_base"] = api_base + + try: + response = completion(**kwargs) + raw_output = _extract_response_content(response) + except Exception: + return None + + parsed = parse_preprocessor_output(raw_output, source_input=message) + if parsed is None: + return None + return parsed + + +def _state_before_last_message(user_transcript: Transcript) -> State | None: + if not user_transcript: + return None + prefix = user_transcript[:-1] + replay = compile_transcript(prefix) + if replay["kind"] != "state": + return None + return replay["state"] + + +def _preprocess_last_user_message(message: str, state: State | None) -> str | None: + try: + heuristic_result = preprocess_heuristic(message) + if ( + heuristic_result["outcome"] == PREPROCESS_OUTCOME_DIRECTIVE + and heuristic_result["directive"] + ): + parsed = parse_preprocessor_output(heuristic_result["directive"]) + if parsed is not None: + return parsed + except Exception: + logger.debug("litellm_proxy: heuristic_exception", exc_info=True) + + if state is None: + return None + + try: + return _llm_fallback_preprocess(message, state) + except Exception: + logger.debug("litellm_proxy: fallback_exception", exc_info=True) + return None + + +class ContextCompilerPreCallHookWithPreprocessor(CustomLogger): + async def async_pre_call_hook( + self, + user_api_key_dict: Any, + cache: Any, + data: dict[str, object], + call_type: str, + ) -> dict[str, object] | str: + del user_api_key_dict, cache + logger.debug("litellm_proxy: call_type=%s", call_type) + if call_type not in _SUPPORTED_CALL_TYPES: + return data + + request_messages = _extract_request_messages(data) + logger.debug("litellm_proxy: message_count=%d", len(request_messages)) + + user_transcript = _extract_user_transcript(request_messages) + logger.debug("litellm_proxy: transcript_len=%d", len(user_transcript)) + + transcript_for_replay = user_transcript + replaced_last_user_message = False + preprocessd: str | None = None + + if user_transcript: + last_user_content = cast(str, user_transcript[-1]["content"]) + prior_state = _state_before_last_message(user_transcript) + preprocessd = _preprocess_last_user_message(last_user_content, prior_state) + logger.debug("litellm_proxy: preprocessd=%r", preprocessd) + if preprocessd: + transcript_for_replay = [*user_transcript] + transcript_for_replay[-1] = {"role": "user", "content": preprocessd} + replaced_last_user_message = True + + logger.debug("litellm_proxy: replaced_last_user_message=%s", replaced_last_user_message) + + replay_result = compile_transcript(transcript_for_replay) + logger.debug("litellm_proxy: replay_kind=%s", replay_result["kind"]) + + if replay_result["kind"] == "confirm": + # Returning a string from this pre-call hook blocks the upstream + # LiteLLM model call under LiteLLM callback semantics. + logger.debug("litellm_proxy: blocking_on_confirm=true") + return replay_result["prompt_to_user"] or "Confirmation required." + + compiled_state = replay_result["state"] + system_message: dict[str, object] = { + "role": "system", + "content": "You are a helpful assistant.\n" + + _render_compiled_state_contract(compiled_state), + } + logger.debug("litellm_proxy: inject_system_message=true") + # Preserve original request messages; only compiler replay input uses + # the preprocessed latest user message when available. + data["messages"] = [system_message, *request_messages] + return data + + +proxy_handler_instance = ContextCompilerPreCallHookWithPreprocessor() diff --git a/python/reference_integrations/openwebui_pipe/README.md b/python/reference_integrations/openwebui_pipe/README.md index 01ab135..1142c4a 100644 --- a/python/reference_integrations/openwebui_pipe/README.md +++ b/python/reference_integrations/openwebui_pipe/README.md @@ -1,10 +1,151 @@ -# Open WebUI pipe reference integration - -- status: placeholder -- intended enforcement point: request construction / context assembly -- intended domain: writing assistant -- intended technology: Open WebUI -- no implementation yet -- examples must use explicit authoritative state -- examples must not derive Context Compiler state from model output -- examples must remain meaningful with an adversarial stub +# Open WebUI Pipe Integration + +Examples of Open WebUI Pipe Functions that use Context Compiler. + +Tested target: Open WebUI `v0.8.12` (latest at time of testing). +Validated at runtime on stock Docker Open WebUI with a real backend model provider. + +Compatibility note: OpenWebUI `0.9.x` changed `Users.get_user_by_id` to async. +These examples support both sync (`0.8.x`) and async (`0.9.x`) user lookup. + +## Files + +- `open_webui_pipe.py`: basic integration, no directive-drafter layer (recommended/default). +- `open_webui_pipe_with_directive_drafter.py`: optional/experimental directive-drafter layer (rule-based check first, then optional model fallback) before `engine.step(...)`. + +## Setup + +The minimal pipe path below is the easiest first-run flow and was runtime-validated in Docker via API flow with a real backend model. + +1. Import `open_webui_pipe.py` (recommended/default) as a Function by URL. +2. Open WebUI installs `context-compiler>=0.7.4` from the function frontmatter requirements. +3. Enable the function. +4. Set `BASE_MODEL_ID` to a valid Open WebUI model id (required). +5. Select the pipe model in chat. + +Open WebUI is a separate runtime and must already be installed/configured separately. +Open WebUI also needs at least one real backend model/provider configured (for example Ollama or OpenAI) so `BASE_MODEL_ID` resolves to an actual model. +Note: The `PROVIDER` environment contract used in LiteLLM examples/demos does not apply to OpenWebUI. OpenWebUI manages providers via its own connection settings and model IDs. + +Checkpoint continuation in these examples requires `context-compiler>=0.7.4`. + +### Model configuration + +- Open: `http://localhost:3000/admin/functions` +- Verify `BASE_MODEL_ID` matches an existing Open WebUI model id exactly. +- Example: + - `BASE_MODEL_ID = llama3.1:8b` +- Model ids are configured in: `Admin Panel → Settings → Models` + +If using `open_webui_pipe_with_directive_drafter.py`: +- Install directive-drafter support in the Open WebUI environment: + - `pip install context-compiler-directive-drafter` +- Open WebUI executes copied functions from a temp/cached location, so + directive-drafter imports/resources must come from the installed package (not + repo-relative paths). +- Set `PREPROCESSOR_PROMPT_PROFILE` to `default` for heuristic-first usage. +- Use `llama` only for LLM-only preprocessing with Llama-family models. +- Prompt files are loaded from the installed package prompts (`default`/`llama` profiles). +- Optional: set `PREPROCESSOR_MODEL_ID` to route fallback precompilation through + a separate model. If unset, fallback uses `BASE_MODEL_ID`. +- Fallback routing is Open WebUI-native (no LiteLLM dependency for this pipe). +- The heuristic directive drafter is intentionally conservative and high-precision, and + may abstain on mixed-prose natural language (for example, `i think we should + use docker`). In those cases, behavior may remain passthrough unless fallback + precompilation returns a validated canonical directive. +- If you configure invalid model ids, the pipe returns explicit runtime errors: + - `BASE_MODEL_ID` not found in Open WebUI models + - `PREPROCESSOR_MODEL_ID` not found in Open WebUI models + +### Docker/manual install fallback + +If frontmatter dependency installs are disabled, offline, or unavailable: + +1. Open a shell in the Open WebUI container: + - `docker exec -it sh` +2. Install the package manually: + - Minimal pipe: `pip install "context-compiler>=0.7.4"` + - Directive-drafter pipe: `pip install "context-compiler>=0.7.4" context-compiler-directive-drafter` +3. Import and enable the function in Open WebUI, then configure valves. + +### Finding valid model ids + +Use the Open WebUI model picker/list to copy exact model ids for `BASE_MODEL_ID` +(and optional `PREPROCESSOR_MODEL_ID` for the directive-drafter pipe). + +## Limitations + +- No durable external persistence (checkpoint continuation is in-process only). +- No multi-worker or cross-process guarantees. +- No Redis/DB/external storage. +- No Filters or Pipelines. +- No production hardening. +- Tied to Open WebUI internal helper/import paths by version. + +## Manual Validation + +Validate these behaviors: +- `clarify` short-circuits the LLM call. +- `passthrough` forwards input without state injection. +- `update` forwards with compiler state (`[[cc_state]]`) added to the request. +- chat isolation works with real chat ids. +- state is lost after restart (no external persistence). +- non-text input is bypassed. + +Note: In the OpenWebUI example pipes, recognized directive-only `update` +decisions return fixed local acknowledgments and do not call the +downstream LLM. +Both pipes support an exact local inspection command: `show state`. +When the latest user message is exactly `show state` (case-insensitive after trim), +the pipe returns current compiler state locally and does not call the downstream model. +When trace is enabled, responses include concise evidence of decision kind, +active state, downstream LLM call/no-call, and whether state was injected. + +Decision flow in both pipes: +- `passthrough`: call the downstream model with normal input. +- `clarify`: show `prompt_to_user`; do not change saved state. +- `update`: state changed; render local acknowledgment for directive-only input, or call downstream model with updated state injected. + +For the directive-drafter pipe, if `engine.has_pending_clarification()` is true, bypass directive drafting and pass raw input directly to `engine.step(...)`. + +## Behavioral comparisons + +**Case 1** + +- prompt(s): `clear state` → `change premise to formal tone` +- base model: “To adjust the tone… provide the original content…” +- basic pipe: `No premise exists yet. Use 'set premise ...' first.` +- directive-drafter pipe: `No premise exists yet. Use 'set premise ...' first.` +- why this matters: lifecycle rule is enforced in a fixed, repeatable way; base model drifts into generic rewriting help. + +**Case 2** + +- prompt(s): `clear state` → `use docker` → `prohibit docker` +- base model: generic Docker/prohibition guidance text +- basic pipe: `'docker' is already in use. Only one policy per item is allowed. Use 'reset policies' to change it.` +- directive-drafter pipe: same conflict clarify +- why this matters: the app asks before applying a conflicting change. + +**Case 3** + +- prompt(s): `clear state` → `use podman instead of docker` +- base model: generic “how to switch to Podman” tutorial +- basic pipe: `No exact policy found for "docker". Replacement requires an exact policy match...` +- directive-drafter pipe: same replacement clarify +- why this matters: the app only replaces a policy when the old item already exists. + +**Case 4** + +- prompt(s): `clear state` → `set premise to concise replies` +- base model: accepts conversational style phrasing +- basic pipe: `Did you mean 'set premise concise replies'?` +- directive-drafter pipe: same clarify (near-miss is not rewritten) +- why this matters: near-miss text is not silently rewritten. + +**Case 5** + +- prompt(s): `clear state` → `change premise concise replies` +- base model: generic “please clarify changes” response +- basic pipe: `Did you mean 'change premise to concise replies'?` +- directive-drafter pipe: same clarify (near-miss is passed through unchanged) +- why this matters: the app waits for explicit, valid directive text before changing state. diff --git a/python/reference_integrations/openwebui_pipe/open_webui_pipe.py b/python/reference_integrations/openwebui_pipe/open_webui_pipe.py new file mode 100644 index 0000000..f3c373a --- /dev/null +++ b/python/reference_integrations/openwebui_pipe/open_webui_pipe.py @@ -0,0 +1,693 @@ +""" +title: Context Compiler Pipe +author: rlippmann +author_url: https://github.com/rlippmann/context-compiler +funding_url: https://github.com/rlippmann/context-compiler +version: 0.9.3 +requirements: context-compiler>=0.7.4 + +Minimal Open WebUI Pipe integration for Context Compiler. + +This integration demonstrates mapping Context Compiler `Decision` output into +Open WebUI request flow. + +Scope is intentionally limited: +- Single Pipe Function for Open WebUI 0.8.x and 0.9.x. +- In-memory per-process engine map keyed by chat key. +- No persistence, no multi-worker coordination, no external storage. +""" + +import inspect +import json +import logging +import re +from collections.abc import AsyncIterator +from typing import Any, cast + +from fastapi import Request # type: ignore[import-not-found] +from open_webui.models.users import Users # type: ignore[import-not-found] +from open_webui.utils.chat import generate_chat_completion # type: ignore[import-not-found] + +try: + from pydantic import BaseModel, Field +except ModuleNotFoundError: + # Keep this import optional: CI/tests run without integration extras. + # These lightweight fallbacks keep import-time behavior deterministic so + # coverage exercises the pipe module without pydantic installed. + class BaseModel: # type: ignore[no-redef] + def __init__(self, **kwargs: object) -> None: + for key, value in kwargs.items(): + setattr(self, key, value) + + def Field(*, default: Any, description: str = "") -> Any: # type: ignore[no-redef] + del description + return default + + +from context_compiler import ( + DECISION_CLARIFY, + DECISION_PASSTHROUGH, + DECISION_UPDATE, + POLICY_PROHIBIT, + POLICY_USE, + State, + create_engine, + get_clarify_prompt, + get_decision_state, + get_policy_items, + get_premise_value, + is_clarify, + is_passthrough, + is_update, +) +from context_compiler.engine import Engine +from context_compiler.observability import build_compact_trace_text + +logger = logging.getLogger(__name__) + +_CC_MARKER = "[[cc_state]]" +_ENGINES_BY_CHAT_KEY: dict[str, Engine] = {} +# Example-only in-memory checkpoint store. +# This keeps continuation state only for the current process lifetime. +# Real deployments should persist checkpoints externally (DB/Redis/etc.), +# or restart continuity for pending flows will be lost. +_CHECKPOINTS_BY_CHAT_KEY: dict[str, str] = {} + + +def _resolve_chat_key( + user: dict[str, Any], + chat_id: str | None, + metadata: dict[str, Any] | None, +) -> str: + """Resolve chat key from reserved args with a minimal fallback. + + Resolution order: + 1. ``__chat_id__`` + 2. ``__metadata__["chat_id"]`` + 3. ``no-chat-id:`` + + The fallback key is a degraded convenience for this minimal integration and + is not a strong chat-isolation guarantee. + """ + if chat_id: + return chat_id + if isinstance(metadata, dict): + metadata_chat_id = metadata.get("chat_id") + if isinstance(metadata_chat_id, str) and metadata_chat_id: + return metadata_chat_id + user_id = str(user["id"]) + return f"no-chat-id:{user_id}" + + +def _extract_latest_user_text(messages: list[dict[str, Any]]) -> str | None: + """Return latest plain-text user content, scanning from the end. + + Uses the last message with ``role == "user"``. Only plain string content is + eligible for compilation. Non-text or missing-user cases return ``None`` so + the caller can bypass compiler behavior. + """ + for message in reversed(messages): + if message.get("role") != "user": + continue + content = message.get("content") + if isinstance(content, str): + return content + return None + return None + + +def _render_compiler_state_block(state: State) -> str: + """Render deterministic compiler-owned state block text. + + The first line is ``[[cc_state]]``. Optional lines follow for ``Premise``, + ``Use``, and ``Prohibit``. Policy items are rendered alphabetically, and + identical state must produce identical output bytes. + """ + lines: list[str] = [_CC_MARKER] + + premise = get_premise_value(state) + if premise is not None: + lines.append(f"Premise: {premise}") + + use_items = sorted(get_policy_items(state, POLICY_USE)) + if use_items: + lines.append("Use: " + ", ".join(use_items)) + + prohibit_items = sorted(get_policy_items(state, POLICY_PROHIBIT)) + if prohibit_items: + lines.append("Prohibit: " + ", ".join(prohibit_items)) + + return "\n".join(lines) + + +def _render_show_state_summary(engine: Engine) -> str: + premise = get_premise_value(engine.state) + use_items = sorted(get_policy_items(engine.state, POLICY_USE)) + prohibit_items = sorted(get_policy_items(engine.state, POLICY_PROHIBIT)) + pending = engine.has_pending_clarification() + + use_text = ", ".join(use_items) if use_items else "none" + prohibit_text = ", ".join(prohibit_items) if prohibit_items else "none" + premise_text = premise if premise is not None else "none" + pending_text = "yes" if pending else "no" + + return ( + f"Premise: {premise_text}\n" + f"Use: {use_text}\n" + f"Prohibit: {prohibit_text}\n" + f"Pending clarification: {pending_text}" + ) + + +def _replace_compiler_system_message( + messages: list[dict[str, Any]], + rendered_state_block: str, +) -> list[dict[str, Any]]: + """Replace compiler-owned state messages while preserving other order. + + Compiler-owned messages are identified by ``[[cc_state]]`` prefix. Existing + compiler-owned system messages are removed, and one fresh compiler-owned + system message is inserted after the last remaining system message, else at + index ``0``. Relative order of non-compiler messages is preserved. + + Invariant: exactly one compiler-owned state message exists afterward. + """ + filtered_messages: list[dict[str, Any]] = [] + last_system_index = -1 + + for message in messages: + role = message.get("role") + content = message.get("content") + if role == "system" and isinstance(content, str) and content.startswith(_CC_MARKER): + continue + + filtered_messages.append(message) + if role == "system": + last_system_index = len(filtered_messages) - 1 + + insert_at = last_system_index + 1 if last_system_index >= 0 else 0 + compiler_message: dict[str, Any] = {"role": "system", "content": rendered_state_block} + return [ + *filtered_messages[:insert_at], + compiler_message, + *filtered_messages[insert_at:], + ] + + +def _normalize_state(value: object) -> State: + if isinstance(value, dict): + return cast(State, value) + return {"premise": None, "policies": {}, "version": 2} + + +def _has_non_empty_authoritative_state(state: State) -> bool: + if get_premise_value(state) is not None: + return True + return bool(get_policy_items(state, POLICY_USE) or get_policy_items(state, POLICY_PROHIBIT)) + + +def _build_compact_trace_text( + *, + decision: object, + state_before: object, + state_after: object, + llm_called: bool, + state_injected: str, +) -> str: + return build_compact_trace_text( + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=llm_called, + state_injected=state_injected, + ) + + +def _strip_trace_block_from_text(content: str) -> str: + marker = "Context Compiler trace" + index = content.find(marker) + if index < 0: + return content + return content[:index].rstrip() + + +def _strip_trace_blocks_from_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + cleaned: list[dict[str, Any]] = [] + for message in messages: + msg = dict(message) + content = msg.get("content") + if isinstance(content, str): + msg["content"] = _strip_trace_block_from_text(content) + cleaned.append(msg) + return cleaned + + +def _build_forward_messages( + raw_messages: object, + *, + state: State | None = None, +) -> list[dict[str, Any]]: + """Build forwarded messages with trace stripping and optional state injection.""" + messages = ( + _strip_trace_blocks_from_messages([msg for msg in raw_messages if isinstance(msg, dict)]) + if isinstance(raw_messages, list) + else [] + ) + if state is not None and _has_non_empty_authoritative_state(state): + return _replace_compiler_system_message( + messages, + _render_compiler_state_block(state), + ) + return messages + + +def _strip_existing_trace_from_chunk(chunk: object) -> object: + if isinstance(chunk, str): + return _strip_trace_block_from_text(chunk) + if isinstance(chunk, bytes): + decoded = chunk.decode("utf-8", errors="ignore") + cleaned = _strip_trace_block_from_text(decoded) + return cleaned.encode("utf-8") + return chunk + + +def _render_item_label(value: str) -> str: + return re.sub(r"\s+", " ", value).strip().lower() + + +def _near_miss_directive_clarify(value: str) -> str | None: + normalized = re.sub(r"\s+", " ", value.strip()) + lower = normalized.lower() + + if lower in {"reset premise", "reset premises", "clear premises"}: + return "Unknown directive.\nUse 'clear premise' or 'reset policies'." + if lower.startswith("set premise to "): + return "Invalid premise syntax.\nUse 'set premise '." + if lower.startswith("change premise ") and not lower.startswith("change premise to "): + return "Invalid premise syntax.\nUse 'change premise to '." + return None + + +def _summarize_update_from_input(user_input: str) -> str: + normalized = re.sub(r"\s+", " ", user_input.strip()) + lower = normalized.lower() + + if lower == "clear state": + return "State cleared." + if lower == "clear premise": + return "Premise cleared." + if lower == "reset policies": + return "Policies reset." + + replacement_match = re.match( + r"^use\s+(.+?)\s+instead\s+of\s+(.+)$", normalized, flags=re.IGNORECASE + ) + if replacement_match is not None: + item = _render_item_label(replacement_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + use_match = re.match(r"^use\s+(.+)$", normalized, flags=re.IGNORECASE) + if use_match is not None: + item = _render_item_label(use_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + prohibit_match = re.match(r"^prohibit\s+(.+)$", normalized, flags=re.IGNORECASE) + if prohibit_match is not None: + item = _render_item_label(prohibit_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Prohibit {item}." + + remove_policy_match = re.match(r"^remove\s+policy\s+(.+)$", normalized, flags=re.IGNORECASE) + if remove_policy_match is not None: + item = _render_item_label(remove_policy_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Removed policy {item}." + + return "State updated." + + +def _is_administrative_update_input(user_input: str) -> bool: + normalized = re.sub(r"\s+", " ", user_input.strip()).lower() + return ( + normalized == "clear state" + or normalized == "clear premise" + or normalized == "reset policies" + or normalized.startswith("remove policy ") + ) + + +class Pipe: + """Map Context Compiler decisions into Open WebUI pipe behavior. + + - ``clarify`` returns plain text and skips model forwarding. + - ``passthrough`` forwards with minimal mutation. + - ``update`` returns deterministic local acknowledgement (no model call). + """ + + class Valves(BaseModel): + BASE_MODEL_ID: str = Field( + default="", + description=( + "Required Open WebUI model id used for forwarding. Must exactly match a " + "configured model id in Open WebUI (not arbitrary text), for example: " + "llama3.1:8b." + ), + ) + SHOW_CONTEXT_COMPILER_TRACE: bool = Field( + default=False, + description="Include concise Context Compiler trace text in responses.", + ) + + def __init__(self) -> None: + self.valves = self.Valves() + + def _is_model_not_found_text(self, value: object) -> bool: + if not isinstance(value, str): + return False + return "model not found" in value.lower() + + def _contains_model_not_found(self, value: object) -> bool: + if self._is_model_not_found_text(value): + return True + if isinstance(value, dict): + return any(self._contains_model_not_found(v) for v in value.values()) + if isinstance(value, list): + return any(self._contains_model_not_found(v) for v in value) + return False + + def _normalize_forward_error(self, response: Any) -> str | None: + if self._contains_model_not_found(response): + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is invalid or not " + "configured in Open WebUI. Configure a valid model id in " + "Admin Panel → Settings → Models." + ) + return None + + def _normalize_forward_exception(self, exc: Exception) -> str | None: + detail = getattr(exc, "detail", None) + if self._contains_model_not_found(detail) or self._contains_model_not_found(str(exc)): + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is invalid or not " + "configured in Open WebUI. Configure a valid model id in " + "Admin Panel → Settings → Models." + ) + return None + + def _trace_enabled(self) -> bool: + return bool(getattr(self.valves, "SHOW_CONTEXT_COMPILER_TRACE", False)) + + def _append_trace_to_response(self, response: Any, trace_text: str) -> Any: + body_iterator = getattr(response, "body_iterator", None) + if body_iterator is not None and callable(getattr(body_iterator, "__aiter__", None)): + response.body_iterator = self._append_trace_to_stream( + cast(AsyncIterator[object], body_iterator), trace_text + ) + return response + aiter = getattr(response, "__aiter__", None) + if callable(aiter): + return self._append_trace_to_stream(cast(AsyncIterator[object], response), trace_text) + if isinstance(response, str): + cleaned = _strip_trace_block_from_text(response) + return f"{cleaned}\n\n{trace_text}" + if isinstance(response, dict): + choices = response.get("choices") + if isinstance(choices, list) and choices: + first = choices[0] + if isinstance(first, dict): + message = first.get("message") + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + cleaned = _strip_trace_block_from_text(content) + message["content"] = f"{cleaned}\n\n{trace_text}" + return response + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, list) and choices_attr: + first_choice = choices_attr[0] + message_attr = getattr(first_choice, "message", None) + content_attr = getattr(message_attr, "content", None) + if message_attr is not None and isinstance(content_attr, str): + cleaned = _strip_trace_block_from_text(content_attr) + message_attr.content = f"{cleaned}\n\n{trace_text}" + return response + return response + + def _append_trace_to_stream( + self, stream: AsyncIterator[object], trace_text: str + ) -> AsyncIterator[object]: + async def _wrapped() -> AsyncIterator[object]: + chunk_type: type[str] | type[bytes] | None = None + saw_done = False + trace_json = json.dumps({"choices": [{"delta": {"content": f"\n\n{trace_text}"}}]}) + trace_event = f"data: {trace_json}\n\n" + + def _matches_done(value: str) -> bool: + normalized = value.strip() + return normalized == "data: [DONE]" or normalized == "[DONE]" + + async for chunk in stream: + if chunk_type is None: + if isinstance(chunk, bytes): + chunk_type = bytes + elif isinstance(chunk, str): + chunk_type = str + if isinstance(chunk, bytes): + decoded = chunk.decode("utf-8", errors="ignore") + if _matches_done(decoded): + saw_done = True + yield trace_event.encode("utf-8") + yield chunk + continue + elif isinstance(chunk, str) and _matches_done(chunk): + saw_done = True + yield trace_event + yield chunk + continue + yield _strip_existing_trace_from_chunk(chunk) + if saw_done: + return + suffix = f"\n\n{trace_text}" + if chunk_type is bytes: + yield suffix.encode("utf-8") + else: + yield suffix + + return _wrapped() + + def _with_trace( + self, + response: Any, + *, + original_input: str, + compiler_input: str, + decision: object, + state_before: object, + state_after: object, + llm_called: bool, + preprocessor_output: str | None = None, + state_injected: str = "no", + ) -> Any: + if not self._trace_enabled(): + return response + del original_input, compiler_input, preprocessor_output + trace_text = _build_compact_trace_text( + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=llm_called, + state_injected=state_injected, + ) + return self._append_trace_to_response(response, trace_text) + + async def _forward_passthrough( + self, + body: dict[str, Any], + user_payload: dict[str, Any], + request: Request, + *, + state: State | None = None, + ) -> Any: + """Forward with model override and optional compiler-owned state injection.""" + payload = {**body} + payload["model"] = self.valves.BASE_MODEL_ID + payload["messages"] = _build_forward_messages(body.get("messages"), state=state) + user = Users.get_user_by_id(user_payload["id"]) + if inspect.isawaitable(user): + user = await user + try: + response = await generate_chat_completion(request, payload, user) + except Exception as exc: + normalized_exception = self._normalize_forward_exception(exc) + if normalized_exception is not None: + return normalized_exception + raise + normalized_error = self._normalize_forward_error(response) + if normalized_error is not None: + return normalized_error + return response + + async def _forward_update( + self, + body: dict[str, Any], + user_payload: dict[str, Any], + request: Request, + state: State, + ) -> Any: + """Forward with one compiler-owned state message based on current state. + + The body is shallow-copied, ``model`` is overridden, and exactly one + compiler-owned message is inserted/replaced before forwarding. + """ + payload = {**body} + payload["model"] = self.valves.BASE_MODEL_ID + + payload["messages"] = _build_forward_messages(body.get("messages"), state=state) + + user = Users.get_user_by_id(user_payload["id"]) + if inspect.isawaitable(user): + user = await user + try: + response = await generate_chat_completion(request, payload, user) + except Exception as exc: + normalized_exception = self._normalize_forward_exception(exc) + if normalized_exception is not None: + return normalized_exception + raise + normalized_error = self._normalize_forward_error(response) + if normalized_error is not None: + return normalized_error + return response + + async def pipe( + self, + body: dict[str, Any], + __user__: dict[str, Any], + __request__: Request, + __chat_id__: str | None = None, + __metadata__: dict[str, Any] | None = None, + ) -> Any: + """Run minimal host flow around compiler decisions. + + Flow: + - Extract latest user text. + - Bypass compiler for non-text or missing-user turns. + - Resolve chat key and get/create per-chat engine. + - Call ``engine.step(...)``. + - Map ``clarify`` / ``passthrough`` / ``update`` outcomes. + """ + raw_messages = body.get("messages") + messages = ( + [msg for msg in raw_messages if isinstance(msg, dict)] + if isinstance(raw_messages, list) + else [] + ) + base_model_id = self.valves.BASE_MODEL_ID.strip() + current_model_id = str(body.get("model", "")).strip() + if not base_model_id: + return "Context Compiler pipe misconfigured: BASE_MODEL_ID is required." + if current_model_id and base_model_id == current_model_id: + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID must not match " + "the selected pipe model id to avoid recursive routing." + ) + + latest_user_text = _extract_latest_user_text(messages) + logger.debug("pipe: user_input_found=%s", latest_user_text is not None) + + if latest_user_text is None: + return await self._forward_passthrough(body, __user__, __request__) + + chat_key = _resolve_chat_key(__user__, __chat_id__, __metadata__) + engine = _ENGINES_BY_CHAT_KEY.get(chat_key) + if engine is None: + engine = create_engine() + checkpoint = _CHECKPOINTS_BY_CHAT_KEY.get(chat_key) + if checkpoint is not None: + engine.import_checkpoint_json(checkpoint) + _ENGINES_BY_CHAT_KEY[chat_key] = engine + + if latest_user_text.strip().lower() == "show state": + return _render_show_state_summary(engine) + + state_before = engine.state + logger.debug("pipe: engine_input=%r", latest_user_text) + decision = engine.step(latest_user_text) + if is_clarify(decision): + kind = DECISION_CLARIFY + elif is_update(decision): + kind = DECISION_UPDATE + else: + kind = DECISION_PASSTHROUGH + logger.debug("pipe: decision=%s", kind) + near_miss_prompt = _near_miss_directive_clarify(latest_user_text) + state_after = get_decision_state(decision) + if state_after is None: + state_after = engine.state + + if is_clarify(decision): + _CHECKPOINTS_BY_CHAT_KEY[chat_key] = engine.export_checkpoint_json() + return self._with_trace( + near_miss_prompt or get_clarify_prompt(decision) or "", + original_input=latest_user_text, + compiler_input=latest_user_text, + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=False, + ) + if near_miss_prompt is not None and is_passthrough(decision): + return self._with_trace( + near_miss_prompt, + original_input=latest_user_text, + compiler_input=latest_user_text, + decision={"kind": DECISION_CLARIFY, "prompt_to_user": near_miss_prompt}, + state_before=state_before, + state_after=state_after, + llm_called=False, + ) + if is_passthrough(decision): + compiled_state = _normalize_state(state_after) + state_injected = "yes" if _has_non_empty_authoritative_state(compiled_state) else "no" + response = await self._forward_passthrough( + body, __user__, __request__, state=compiled_state + ) + return self._with_trace( + response, + original_input=latest_user_text, + compiler_input=latest_user_text, + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=True, + state_injected=state_injected, + ) + if is_update(decision): + _CHECKPOINTS_BY_CHAT_KEY[chat_key] = engine.export_checkpoint_json() + return self._with_trace( + _summarize_update_from_input(latest_user_text), + original_input=latest_user_text, + compiler_input=latest_user_text, + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=False, + ) + + compiled_state = _normalize_state(state_after) + state_injected = "yes" if _has_non_empty_authoritative_state(compiled_state) else "no" + response = await self._forward_passthrough( + body, __user__, __request__, state=compiled_state + ) + return self._with_trace( + response, + original_input=latest_user_text, + compiler_input=latest_user_text, + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=True, + state_injected=state_injected, + ) diff --git a/python/reference_integrations/openwebui_pipe/open_webui_pipe_with_directive_drafter.py b/python/reference_integrations/openwebui_pipe/open_webui_pipe_with_directive_drafter.py new file mode 100644 index 0000000..090cd93 --- /dev/null +++ b/python/reference_integrations/openwebui_pipe/open_webui_pipe_with_directive_drafter.py @@ -0,0 +1,984 @@ +""" +title: Context Compiler Pipe (Directive Drafter) +author: rlippmann +author_url: https://github.com/rlippmann/context-compiler +funding_url: https://github.com/rlippmann/context-compiler +version: 0.9.3 +requirements: context-compiler>=0.7.4, context-compiler-directive-drafter>=0.1.0 + +Open WebUI integration with Context Compiler directive drafter. + +This example extends `open_webui_pipe.py` by inserting a directive-drafting step: + +1. Run heuristic directive drafter (fast, high-precision cases) +2. Fall back to Open WebUI-native model completion when needed +3. Pass resulting directive (or original input) to `engine.step(...)` + +Core decision handling remains the same as the base integration. +""" + +import inspect +import json +import logging +import re +from collections.abc import AsyncIterator +from importlib.resources import as_file, files +from importlib.resources.abc import Traversable +from typing import Any, Literal, cast + +from fastapi import Request # type: ignore[import-not-found] +from open_webui.models.users import Users # type: ignore[import-not-found] +from open_webui.utils.chat import generate_chat_completion # type: ignore[import-not-found] +from open_webui.utils.models import get_all_models # type: ignore[import-not-found] + +try: + from pydantic import BaseModel, Field +except ModuleNotFoundError: + # Keep this import optional: CI/tests run without integration extras. + # These lightweight fallbacks keep import-time behavior deterministic so + # coverage exercises the pipe module without pydantic installed. + class BaseModel: # type: ignore[no-redef] + def __init__(self, **kwargs: object) -> None: + for key, value in kwargs.items(): + setattr(self, key, value) + + def Field(*, default: Any, description: str = "") -> Any: # type: ignore[no-redef] + del description + return default + + +from context_compiler import ( + DECISION_CLARIFY, + DECISION_PASSTHROUGH, + DECISION_UPDATE, + POLICY_PROHIBIT, + POLICY_USE, + State, + create_engine, + get_clarify_prompt, + get_decision_state, + get_policy_items, + get_premise_value, + is_clarify, + is_passthrough, + is_update, +) +from context_compiler.engine import Engine +from context_compiler.observability import build_compact_trace_text +from context_compiler_directive_drafter import ( + PREPROCESS_OUTCOME_DIRECTIVE, + parse_preprocessor_output, + preprocess_heuristic, + render_prompt, +) + +logger = logging.getLogger(__name__) + +_CC_MARKER = "[[cc_state]]" +_ENGINES_BY_CHAT_KEY: dict[str, Engine] = {} +# Example-only in-memory checkpoint store. +# This keeps continuation state only for the current process lifetime. +# Real deployments should persist checkpoints externally (DB/Redis/etc.), +# or restart continuity for pending flows will be lost. +_CHECKPOINTS_BY_CHAT_KEY: dict[str, str] = {} +_PROMPTS_DIR = files("context_compiler_directive_drafter").joinpath("prompts") + + +def _is_directive_shaped_input(message: str) -> bool: + normalized = re.sub(r"\s+", " ", message.strip()).lower() + return ( + normalized.startswith("use") + or normalized.startswith("prohibit") + or normalized.startswith("remove policy") + or normalized.startswith("set premise") + or normalized.startswith("change premise") + or normalized.startswith("clear") + or normalized.startswith("reset") + ) + + +def _prompt_file_path(profile: str) -> Traversable: + # Runtime prompt selection for fallback precompilation: + # - default: most instruction-following models + # - llama: models that need tighter prompt guidance + if profile == "llama": + return _PROMPTS_DIR.joinpath("llama.txt") + return _PROMPTS_DIR.joinpath("default.txt") + + +def _resolve_chat_key( + user: dict[str, Any], + chat_id: str | None, + metadata: dict[str, Any] | None, +) -> str: + if chat_id: + return chat_id + if isinstance(metadata, dict): + metadata_chat_id = metadata.get("chat_id") + if isinstance(metadata_chat_id, str) and metadata_chat_id: + return metadata_chat_id + user_id = str(user["id"]) + return f"no-chat-id:{user_id}" + + +def _extract_latest_user_text(messages: list[dict[str, Any]]) -> str | None: + for message in reversed(messages): + if message.get("role") != "user": + continue + content = message.get("content") + if isinstance(content, str): + return content + return None + return None + + +def _has_pending_clarification(engine: Engine) -> bool: + return engine.has_pending_clarification() + + +def _render_compiler_state_block(state: State) -> str: + lines: list[str] = [_CC_MARKER] + + premise = get_premise_value(state) + if premise is not None: + lines.append(f"Premise: {premise}") + + use_items = sorted(get_policy_items(state, POLICY_USE)) + if use_items: + lines.append("Use: " + ", ".join(use_items)) + + prohibit_items = sorted(get_policy_items(state, POLICY_PROHIBIT)) + if prohibit_items: + lines.append("Prohibit: " + ", ".join(prohibit_items)) + + return "\n".join(lines) + + +def _render_show_state_summary(engine: Engine) -> str: + premise = get_premise_value(engine.state) + use_items = sorted(get_policy_items(engine.state, POLICY_USE)) + prohibit_items = sorted(get_policy_items(engine.state, POLICY_PROHIBIT)) + pending = engine.has_pending_clarification() + + use_text = ", ".join(use_items) if use_items else "none" + prohibit_text = ", ".join(prohibit_items) if prohibit_items else "none" + premise_text = premise if premise is not None else "none" + pending_text = "yes" if pending else "no" + + return ( + f"Premise: {premise_text}\n" + f"Use: {use_text}\n" + f"Prohibit: {prohibit_text}\n" + f"Pending clarification: {pending_text}" + ) + + +def _replace_compiler_system_message( + messages: list[dict[str, Any]], + rendered_state_block: str, +) -> list[dict[str, Any]]: + filtered_messages: list[dict[str, Any]] = [] + last_system_index = -1 + + for message in messages: + role = message.get("role") + content = message.get("content") + if role == "system" and isinstance(content, str) and content.startswith(_CC_MARKER): + continue + + filtered_messages.append(message) + if role == "system": + last_system_index = len(filtered_messages) - 1 + + insert_at = last_system_index + 1 if last_system_index >= 0 else 0 + compiler_message: dict[str, Any] = {"role": "system", "content": rendered_state_block} + return [ + *filtered_messages[:insert_at], + compiler_message, + *filtered_messages[insert_at:], + ] + + +def _normalize_state(value: object) -> State: + if isinstance(value, dict): + return cast(State, value) + return {"premise": None, "policies": {}, "version": 2} + + +def _has_non_empty_authoritative_state(state: State) -> bool: + if get_premise_value(state) is not None: + return True + return bool(get_policy_items(state, POLICY_USE) or get_policy_items(state, POLICY_PROHIBIT)) + + +def _build_compact_trace_text( + *, + decision: object, + state_before: object, + state_after: object, + llm_called: bool, + state_injected: str, +) -> str: + return build_compact_trace_text( + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=llm_called, + state_injected=state_injected, + ) + + +def _strip_trace_block_from_text(content: str) -> str: + marker = "Context Compiler trace" + index = content.find(marker) + if index < 0: + return content + return content[:index].rstrip() + + +def _strip_trace_blocks_from_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + cleaned: list[dict[str, Any]] = [] + for message in messages: + msg = dict(message) + content = msg.get("content") + if isinstance(content, str): + msg["content"] = _strip_trace_block_from_text(content) + cleaned.append(msg) + return cleaned + + +def _build_forward_messages( + raw_messages: object, + *, + state: State | None = None, +) -> list[dict[str, Any]]: + """Build forwarded messages with trace stripping and optional state injection.""" + messages = ( + _strip_trace_blocks_from_messages([msg for msg in raw_messages if isinstance(msg, dict)]) + if isinstance(raw_messages, list) + else [] + ) + if state is not None and _has_non_empty_authoritative_state(state): + return _replace_compiler_system_message( + messages, + _render_compiler_state_block(state), + ) + return messages + + +def _strip_existing_trace_from_chunk(chunk: object) -> object: + if isinstance(chunk, str): + return _strip_trace_block_from_text(chunk) + if isinstance(chunk, bytes): + decoded = chunk.decode("utf-8", errors="ignore") + cleaned = _strip_trace_block_from_text(decoded) + return cleaned.encode("utf-8") + return chunk + + +def _render_item_label(value: str) -> str: + return re.sub(r"\s+", " ", value).strip().lower() + + +def _near_miss_directive_clarify(value: str) -> str | None: + normalized = re.sub(r"\s+", " ", value.strip()) + lower = normalized.lower() + + if lower in {"reset premise", "reset premises", "clear premises"}: + return "Unknown directive.\nUse 'clear premise' or 'reset policies'." + if lower.startswith("set premise to "): + return "Invalid premise syntax.\nUse 'set premise '." + if lower.startswith("change premise ") and not lower.startswith("change premise to "): + return "Invalid premise syntax.\nUse 'change premise to '." + return None + + +def _summarize_update_from_input(user_input: str) -> str: + normalized = re.sub(r"\s+", " ", user_input.strip()) + lower = normalized.lower() + + if lower == "clear state": + return "State cleared." + if lower == "clear premise": + return "Premise cleared." + if lower == "reset policies": + return "Policies reset." + + replacement_match = re.match( + r"^use\s+(.+?)\s+instead\s+of\s+(.+)$", normalized, flags=re.IGNORECASE + ) + if replacement_match is not None: + item = _render_item_label(replacement_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + use_match = re.match(r"^use\s+(.+)$", normalized, flags=re.IGNORECASE) + if use_match is not None: + item = _render_item_label(use_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Use {item}." + + prohibit_match = re.match(r"^prohibit\s+(.+)$", normalized, flags=re.IGNORECASE) + if prohibit_match is not None: + item = _render_item_label(prohibit_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Prohibit {item}." + + remove_policy_match = re.match(r"^remove\s+policy\s+(.+)$", normalized, flags=re.IGNORECASE) + if remove_policy_match is not None: + item = _render_item_label(remove_policy_match.group(1).rstrip(" .!?")) + if item: + return f"State updated: Removed policy {item}." + + return "State updated." + + +def _is_administrative_update_input(user_input: str) -> bool: + normalized = re.sub(r"\s+", " ", user_input.strip()).lower() + return ( + normalized == "clear state" + or normalized == "clear premise" + or normalized == "reset policies" + or normalized.startswith("remove policy ") + ) + + +def _extract_completion_content(response: object) -> str | None: + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, list) and choices_attr: + first_choice = choices_attr[0] + message_attr = getattr(first_choice, "message", None) + content_attr = getattr(message_attr, "content", None) + if isinstance(content_attr, str): + return content_attr + + if isinstance(response, dict): + choices = response.get("choices") + if isinstance(choices, list) and choices: + first_choice = choices[0] + if isinstance(first_choice, dict): + message = first_choice.get("message") + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + return content + + return None + + +def _normalize_model_id(value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + +def _is_truthy_bool(value: object) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "on"}: + return True + if normalized in {"false", "0", "off"}: + return False + return False + + +class Pipe: + """Map Context Compiler decisions into Open WebUI pipe behavior. + + This variant adds a directive-drafter stage before ``engine.step(...)``: + heuristic first, then Open WebUI-native LLM fallback. + Update decisions return deterministic local acknowledgement (no model call). + """ + + class Valves(BaseModel): + BASE_MODEL_ID: str = Field( + default="", + description=( + "Required Open WebUI model id used for forwarding. Must exactly match a " + "configured model id in Open WebUI (not arbitrary text), for example: " + "llama3.1:8b." + ), + ) + PREPROCESSOR_MODEL_ID: str | None = Field( + default=None, + description=( + "Optional model id for fallback precompilation (defaults to BASE_MODEL_ID)." + ), + ) + PREPROCESSOR_PROMPT_PROFILE: Literal["default", "llama"] = Field( + default="default", + description="Prompt profile for LLM fallback precompilation.", + ) + ALLOW_MISSING_BASE_MODEL_FOR_DEBUG: bool = Field( + default=False, + description="Allow missing BASE_MODEL_ID for debug/testing only.", + ) + SHOW_CONTEXT_COMPILER_TRACE: bool = Field( + default=False, + description="Include concise Context Compiler trace text in responses.", + ) + + def __init__(self) -> None: + self.valves = self.Valves() + + def _allow_missing_base_model_for_debug(self) -> bool: + return _is_truthy_bool(getattr(self.valves, "ALLOW_MISSING_BASE_MODEL_FOR_DEBUG", False)) + + def _trace_enabled(self) -> bool: + return bool(getattr(self.valves, "SHOW_CONTEXT_COMPILER_TRACE", False)) + + def _append_trace_to_response(self, response: Any, trace_text: str) -> Any: + body_iterator = getattr(response, "body_iterator", None) + if body_iterator is not None and callable(getattr(body_iterator, "__aiter__", None)): + response.body_iterator = self._append_trace_to_stream( + cast(AsyncIterator[object], body_iterator), trace_text + ) + return response + aiter = getattr(response, "__aiter__", None) + if callable(aiter): + return self._append_trace_to_stream(cast(AsyncIterator[object], response), trace_text) + if isinstance(response, str): + cleaned = _strip_trace_block_from_text(response) + return f"{cleaned}\n\n{trace_text}" + if isinstance(response, dict): + choices = response.get("choices") + if isinstance(choices, list) and choices: + first = choices[0] + if isinstance(first, dict): + message = first.get("message") + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + cleaned = _strip_trace_block_from_text(content) + message["content"] = f"{cleaned}\n\n{trace_text}" + return response + choices_attr = getattr(response, "choices", None) + if isinstance(choices_attr, list) and choices_attr: + first_choice = choices_attr[0] + message_attr = getattr(first_choice, "message", None) + content_attr = getattr(message_attr, "content", None) + if message_attr is not None and isinstance(content_attr, str): + cleaned = _strip_trace_block_from_text(content_attr) + message_attr.content = f"{cleaned}\n\n{trace_text}" + return response + return response + + def _append_trace_to_stream( + self, stream: AsyncIterator[object], trace_text: str + ) -> AsyncIterator[object]: + async def _wrapped() -> AsyncIterator[object]: + chunk_type: type[str] | type[bytes] | None = None + saw_done = False + trace_json = json.dumps({"choices": [{"delta": {"content": f"\n\n{trace_text}"}}]}) + trace_event = f"data: {trace_json}\n\n" + + def _matches_done(value: str) -> bool: + normalized = value.strip() + return normalized == "data: [DONE]" or normalized == "[DONE]" + + async for chunk in stream: + if chunk_type is None: + if isinstance(chunk, bytes): + chunk_type = bytes + elif isinstance(chunk, str): + chunk_type = str + if isinstance(chunk, bytes): + decoded = chunk.decode("utf-8", errors="ignore") + if _matches_done(decoded): + saw_done = True + yield trace_event.encode("utf-8") + yield chunk + continue + elif isinstance(chunk, str) and _matches_done(chunk): + saw_done = True + yield trace_event + yield chunk + continue + yield _strip_existing_trace_from_chunk(chunk) + if saw_done: + return + suffix = f"\n\n{trace_text}" + if chunk_type is bytes: + yield suffix.encode("utf-8") + else: + yield suffix + + return _wrapped() + + def _with_trace( + self, + response: Any, + *, + original_input: str, + compiler_input: str, + decision: object, + state_before: object, + state_after: object, + llm_called: bool, + preprocessor_output: str | None = None, + state_injected: str = "no", + ) -> Any: + if not self._trace_enabled(): + return response + del original_input, compiler_input, preprocessor_output + trace_text = _build_compact_trace_text( + decision=decision, + state_before=state_before, + state_after=state_after, + llm_called=llm_called, + state_injected=state_injected, + ) + return self._append_trace_to_response(response, trace_text) + + def _is_model_not_found_text(self, value: object) -> bool: + if not isinstance(value, str): + return False + return "model not found" in value.lower() + + def _contains_model_not_found(self, value: object) -> bool: + if self._is_model_not_found_text(value): + return True + if isinstance(value, dict): + return any(self._contains_model_not_found(v) for v in value.values()) + if isinstance(value, list): + return any(self._contains_model_not_found(v) for v in value) + return False + + def _normalize_forward_error(self, response: Any) -> str | None: + if self._contains_model_not_found(response): + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is invalid or not " + "configured in Open WebUI. Configure a valid model id in " + "Admin Panel → Settings → Models." + ) + return None + + def _normalize_forward_exception(self, exc: Exception) -> str | None: + detail = getattr(exc, "detail", None) + if self._contains_model_not_found(detail) or self._contains_model_not_found(str(exc)): + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is invalid or not " + "configured in Open WebUI. Configure a valid model id in " + "Admin Panel → Settings → Models." + ) + return None + + def _normalize_preprocessor_error(self, response: Any) -> str | None: + if self._contains_model_not_found(response): + return ( + "Context Compiler pipe misconfigured: PREPROCESSOR_MODEL_ID is invalid or " + "not configured in Open WebUI. Configure a valid model id in " + "Admin Panel → Settings → Models." + ) + return None + + def _normalize_preprocessor_exception(self, exc: Exception) -> str | None: + detail = getattr(exc, "detail", None) + if self._contains_model_not_found(detail) or self._contains_model_not_found(str(exc)): + return ( + "Context Compiler pipe misconfigured: PREPROCESSOR_MODEL_ID is invalid or " + "not configured in Open WebUI. Configure a valid model id in " + "Admin Panel → Settings → Models." + ) + return None + + def _resolve_preprocessor_model_id(self, base_model_id: str | None) -> str | None: + preprocessor_model_id = _normalize_model_id(self.valves.PREPROCESSOR_MODEL_ID) + return preprocessor_model_id or base_model_id + + async def _validate_configured_model_ids( + self, + request: Request, + user_payload: dict[str, Any], + *, + base_model_id: str | None, + preprocessor_model_id: str | None, + ) -> str | None: + base_model_id = _normalize_model_id(base_model_id) + preprocessor_model_id = _normalize_model_id(preprocessor_model_id) + # Best-effort preflight: fail closed only for clear missing-model mismatches. + # If model discovery fails, preserve runtime behavior and rely on call-path + # normalization below. + user = Users.get_user_by_id(user_payload["id"]) + if inspect.isawaitable(user): + user = await user + try: + models = await get_all_models(request, user=user) + except Exception: + return None + + known_model_ids: set[str] = set() + if isinstance(models, list): + for model in models: + if not isinstance(model, dict): + continue + model_id = model.get("id") + if isinstance(model_id, str): + known_model_ids.add(model_id) + + if base_model_id and base_model_id not in known_model_ids: + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID was not found " + "in Open WebUI models." + ) + if preprocessor_model_id and preprocessor_model_id not in known_model_ids: + return ( + "Context Compiler pipe misconfigured: PREPROCESSOR_MODEL_ID was not found " + "in Open WebUI models." + ) + return None + + async def _llm_fallback_preprocess( + self, + message: str, + state: State, + *, + request: Request, + user_payload: dict[str, Any], + prompt_profile: str, + model_id: str | None, + ) -> tuple[str | None, str | None]: + model_id = _normalize_model_id(model_id) + if model_id is None: + return None, None + with as_file(_prompt_file_path(prompt_profile)) as prompt_path: + prompt = render_prompt(prompt_path, state) + if prompt is None: + return None, None + + payload: dict[str, Any] = { + "model": model_id, + "stream": False, + "messages": [ + {"role": "system", "content": prompt}, + {"role": "user", "content": message}, + ], + } + user = Users.get_user_by_id(user_payload["id"]) + if inspect.isawaitable(user): + user = await user + try: + response = await generate_chat_completion(request, payload, user) + except Exception as exc: + normalized_exception = self._normalize_preprocessor_exception(exc) + if normalized_exception is not None: + return None, normalized_exception + return None, None + + normalized_error = self._normalize_preprocessor_error(response) + if normalized_error is not None: + return None, normalized_error + + raw_output = _extract_completion_content(response) + parsed = parse_preprocessor_output(raw_output, source_input=message) + if parsed is None: + return None, None + return parsed, None + + async def _preprocess_user_input( + self, + message: str, + state: State, + *, + request: Request, + user_payload: dict[str, Any], + prompt_profile: str, + model_id: str | None, + ) -> tuple[str | None, str | None]: + # Heuristic first for precision, determinism, and low latency. + # If heuristic does not produce a directive, try Open WebUI-native fallback. + heuristic_result = preprocess_heuristic(message) + + if ( + heuristic_result["outcome"] == PREPROCESS_OUTCOME_DIRECTIVE + and heuristic_result["directive"] + ): + parsed = parse_preprocessor_output(heuristic_result["directive"]) + if parsed is not None: + return parsed, None + + if _is_directive_shaped_input(message): + return None, None + + # In debug mode with missing base/preprocessor model ids, skip fallback + # preprocess entirely so we never attempt an empty-model LLM call. + model_id = _normalize_model_id(model_id) + if model_id is None: + return None, None + + return await self._llm_fallback_preprocess( + message, + state, + request=request, + user_payload=user_payload, + prompt_profile=prompt_profile, + model_id=model_id, + ) + + async def _forward_passthrough( + self, + body: dict[str, Any], + user_payload: dict[str, Any], + request: Request, + *, + base_model_id: str | None, + state: State | None = None, + ) -> Any: + if base_model_id is None: + if self._allow_missing_base_model_for_debug(): + return ( + "Context Compiler debug mode: BASE_MODEL_ID is empty; " + "skipping model passthrough." + ) + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is required " + "(or set ALLOW_MISSING_BASE_MODEL_FOR_DEBUG=true for testing)." + ) + payload = {**body} + payload["model"] = base_model_id + payload["messages"] = _build_forward_messages(body.get("messages"), state=state) + user = Users.get_user_by_id(user_payload["id"]) + if inspect.isawaitable(user): + user = await user + try: + response = await generate_chat_completion(request, payload, user) + except Exception as exc: + normalized_exception = self._normalize_forward_exception(exc) + if normalized_exception is not None: + return normalized_exception + raise + normalized_error = self._normalize_forward_error(response) + if normalized_error is not None: + return normalized_error + return response + + async def _forward_update( + self, + body: dict[str, Any], + user_payload: dict[str, Any], + request: Request, + state: State, + *, + base_model_id: str | None, + ) -> Any: + if base_model_id is None: + if self._allow_missing_base_model_for_debug(): + return ( + "Context Compiler debug mode: BASE_MODEL_ID is empty; " + "skipping model passthrough." + ) + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is required " + "(or set ALLOW_MISSING_BASE_MODEL_FOR_DEBUG=true for testing)." + ) + payload = {**body} + payload["model"] = base_model_id + + payload["messages"] = _build_forward_messages(body.get("messages"), state=state) + + user = Users.get_user_by_id(user_payload["id"]) + if inspect.isawaitable(user): + user = await user + try: + response = await generate_chat_completion(request, payload, user) + except Exception as exc: + normalized_exception = self._normalize_forward_exception(exc) + if normalized_exception is not None: + return normalized_exception + raise + normalized_error = self._normalize_forward_error(response) + if normalized_error is not None: + return normalized_error + return response + + async def pipe( + self, + body: dict[str, Any], + __user__: dict[str, Any], + __request__: Request, + __chat_id__: str | None = None, + __metadata__: dict[str, Any] | None = None, + ) -> Any: + # Open WebUI integration entrypoint: + # 1) extract latest user input + # 2) run preprocess (heuristic -> LLM fallback) + # 3) pass directive or original input to engine.step(...) + # 4) map decision back to Open WebUI response behavior + raw_messages = body.get("messages") + messages = ( + [msg for msg in raw_messages if isinstance(msg, dict)] + if isinstance(raw_messages, list) + else [] + ) + base_model_id = _normalize_model_id(self.valves.BASE_MODEL_ID) + preprocessor_model_id = _normalize_model_id(self.valves.PREPROCESSOR_MODEL_ID) + effective_preprocessor_model = preprocessor_model_id or base_model_id + current_model_id = str(body.get("model", "")).strip() + + if not base_model_id and not self._allow_missing_base_model_for_debug(): + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID is required " + "(or set ALLOW_MISSING_BASE_MODEL_FOR_DEBUG=true for testing)." + ) + if base_model_id and current_model_id and base_model_id == current_model_id: + return ( + "Context Compiler pipe misconfigured: BASE_MODEL_ID must not match " + "the selected pipe model id to avoid recursive routing." + ) + if ( + effective_preprocessor_model + and current_model_id + and effective_preprocessor_model == current_model_id + ): + return ( + "Context Compiler pipe misconfigured: PREPROCESSOR_MODEL_ID must not " + "match the selected pipe model id to avoid recursive routing." + ) + + preflight_error = await self._validate_configured_model_ids( + __request__, + __user__, + base_model_id=base_model_id, + preprocessor_model_id=effective_preprocessor_model, + ) + if preflight_error is not None: + return preflight_error + + latest_user_text = _extract_latest_user_text(messages) + logger.debug("preprocessor: user_input_found=%s", latest_user_text is not None) + + if latest_user_text is None: + return await self._forward_passthrough( + body, + __user__, + __request__, + base_model_id=base_model_id, + ) + + chat_key = _resolve_chat_key(__user__, __chat_id__, __metadata__) + engine = _ENGINES_BY_CHAT_KEY.get(chat_key) + if engine is None: + engine = create_engine() + checkpoint = _CHECKPOINTS_BY_CHAT_KEY.get(chat_key) + if checkpoint is not None: + engine.import_checkpoint_json(checkpoint) + _ENGINES_BY_CHAT_KEY[chat_key] = engine + + if latest_user_text.strip().lower() == "show state": + return _render_show_state_summary(engine) + + state_before = engine.state + + preprocessd: str | None = None + preprocess_error: str | None = None + if not _has_pending_clarification(engine): + preprocessd, preprocess_error = await self._preprocess_user_input( + latest_user_text, + engine.state, + request=__request__, + user_payload=__user__, + prompt_profile=self.valves.PREPROCESSOR_PROMPT_PROFILE, + model_id=effective_preprocessor_model, + ) + if preprocess_error is not None: + return preprocess_error + + logger.debug("preprocessor: preprocessd=%r", preprocessd) + # Preserve core behavior: if preprocess yields no directive, use raw user + # text so the compiler still decides clarify/passthrough/update. + compile_input = preprocessd if preprocessd is not None else latest_user_text + + logger.debug("preprocessor: engine_input=%r", compile_input) + decision = engine.step(compile_input) + if is_clarify(decision): + kind = DECISION_CLARIFY + elif is_update(decision): + kind = DECISION_UPDATE + else: + kind = DECISION_PASSTHROUGH + logger.debug("preprocessor: decision=%s", kind) + near_miss_prompt = _near_miss_directive_clarify(latest_user_text) + state_after = get_decision_state(decision) + if state_after is None: + state_after = engine.state + + if is_clarify(decision): + _CHECKPOINTS_BY_CHAT_KEY[chat_key] = engine.export_checkpoint_json() + return self._with_trace( + near_miss_prompt or get_clarify_prompt(decision) or "", + original_input=latest_user_text, + compiler_input=compile_input, + decision=decision, + state_before=state_before, + state_after=state_after, + preprocessor_output=preprocessd, + llm_called=False, + ) + if near_miss_prompt is not None and is_passthrough(decision): + return self._with_trace( + near_miss_prompt, + original_input=latest_user_text, + compiler_input=compile_input, + decision={"kind": DECISION_CLARIFY, "prompt_to_user": near_miss_prompt}, + state_before=state_before, + state_after=state_after, + preprocessor_output=preprocessd, + llm_called=False, + ) + if is_passthrough(decision): + compiled_state = _normalize_state(state_after) + state_injected = "yes" if _has_non_empty_authoritative_state(compiled_state) else "no" + response = await self._forward_passthrough( + body, + __user__, + __request__, + base_model_id=base_model_id, + state=compiled_state, + ) + return self._with_trace( + response, + original_input=latest_user_text, + compiler_input=compile_input, + decision=decision, + state_before=state_before, + state_after=state_after, + preprocessor_output=preprocessd, + llm_called=base_model_id is not None, + state_injected=state_injected, + ) + if is_update(decision): + _CHECKPOINTS_BY_CHAT_KEY[chat_key] = engine.export_checkpoint_json() + return self._with_trace( + _summarize_update_from_input(compile_input), + original_input=latest_user_text, + compiler_input=compile_input, + decision=decision, + state_before=state_before, + state_after=state_after, + preprocessor_output=preprocessd, + llm_called=False, + ) + + compiled_state = _normalize_state(state_after) + state_injected = "yes" if _has_non_empty_authoritative_state(compiled_state) else "no" + response = await self._forward_passthrough( + body, + __user__, + __request__, + base_model_id=base_model_id, + state=compiled_state, + ) + return self._with_trace( + response, + original_input=latest_user_text, + compiler_input=compile_input, + decision=decision, + state_before=state_before, + state_after=state_after, + preprocessor_output=preprocessd, + llm_called=base_model_id is not None, + state_injected=state_injected, + )