diff --git a/openhands-agent-server/openhands/agent_server/openai/models.py b/openhands-agent-server/openhands/agent_server/openai/models.py index 0fc587b5f6..514bd82caf 100644 --- a/openhands-agent-server/openhands/agent_server/openai/models.py +++ b/openhands-agent-server/openhands/agent_server/openai/models.py @@ -3,13 +3,17 @@ from typing import Literal from openai.types import CompletionUsage, Model -from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice, ChoiceDelta from openai.types.chat.chat_completion_message import ChatCompletionMessage from pydantic import BaseModel, ConfigDict OpenAIChatCompletionChoice = Choice +OpenAIChatCompletionChunk = ChatCompletionChunk +OpenAIChatCompletionChunkChoice = ChunkChoice +OpenAIChatCompletionChunkChoiceDelta = ChoiceDelta OpenAIChatCompletionResponse = ChatCompletion OpenAIModel = Model OpenAIResponseMessage = ChatCompletionMessage @@ -29,16 +33,23 @@ class OpenAIContentPart(BaseModel): class OpenAIChatMessage(BaseModel): - role: Literal["system", "user", "assistant", "tool"] + role: Literal["system", "developer", "user", "assistant", "tool"] content: str | list[OpenAIContentPart] | None = None model_config = ConfigDict(extra="ignore") +class OpenAIStreamOptions(BaseModel): + include_usage: bool = False + + model_config = ConfigDict(extra="ignore") + + class OpenAIChatCompletionRequest(BaseModel): model: str messages: list[OpenAIChatMessage] stream: bool = False + stream_options: OpenAIStreamOptions | None = None model_config = ConfigDict(extra="ignore") diff --git a/openhands-agent-server/openhands/agent_server/openai/router.py b/openhands-agent-server/openhands/agent_server/openai/router.py index b3e3118e7e..a549a2c90f 100644 --- a/openhands-agent-server/openhands/agent_server/openai/router.py +++ b/openhands-agent-server/openhands/agent_server/openai/router.py @@ -4,6 +4,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status +from fastapi.responses import StreamingResponse from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer from openhands.agent_server.config import Config @@ -15,6 +16,7 @@ OpenAIModelListResponse, ) from openhands.agent_server.openai.service import ( + iter_openai_chat_completion_sse, list_openai_models, run_chat_completion, ) @@ -84,12 +86,26 @@ async def create_chat_completion( UUID | None, Header(alias="X-OpenHands-ServerConversation-ID") ] = None, conversation_service: ConversationService = Depends(get_conversation_service), -) -> OpenAIChatCompletionResponse: +) -> OpenAIChatCompletionResponse | StreamingResponse: result = await run_chat_completion( - request=body, + request=body.model_copy(update={"stream": False}) if body.stream else body, config=_get_config(request), conversation_service=conversation_service, reusable_conversation_id=x_openhands_server_conversation_id, ) - response.headers["X-OpenHands-ServerConversation-ID"] = str(result.conversation_id) + conversation_id = str(result.conversation_id) + if body.stream: + include_usage = ( + body.stream_options is not None and body.stream_options.include_usage + ) + return StreamingResponse( + iter_openai_chat_completion_sse( + result.response, + include_usage=include_usage, + ), + media_type="text/event-stream", + headers={"X-OpenHands-ServerConversation-ID": conversation_id}, + ) + + response.headers["X-OpenHands-ServerConversation-ID"] = conversation_id return result.response diff --git a/openhands-agent-server/openhands/agent_server/openai/service.py b/openhands-agent-server/openhands/agent_server/openai/service.py index 9e314764c6..a5599923a4 100644 --- a/openhands-agent-server/openhands/agent_server/openai/service.py +++ b/openhands-agent-server/openhands/agent_server/openai/service.py @@ -1,7 +1,9 @@ """Service logic for the OpenAI-compatible agent-server gateway.""" import asyncio +import json import time +from collections.abc import Iterator from dataclasses import dataclass from uuid import UUID, uuid4 @@ -12,6 +14,9 @@ from openhands.agent_server.event_service import EventService from openhands.agent_server.openai.models import ( OpenAIChatCompletionChoice, + OpenAIChatCompletionChunk, + OpenAIChatCompletionChunkChoice, + OpenAIChatCompletionChunkChoiceDelta, OpenAIChatCompletionRequest, OpenAIChatCompletionResponse, OpenAIChatMessage, @@ -168,7 +173,7 @@ def _latest_user_message(messages: list[OpenAIChatMessage]) -> OpenAIChatMessage def _system_text(messages: list[OpenAIChatMessage]) -> str: text_parts: list[str] = [] for message in messages: - if message.role != "system": + if message.role not in {"system", "developer"}: continue text = _message_text(message) if text: @@ -287,6 +292,81 @@ def _openai_usage_from_state(state: ConversationState) -> OpenAIUsage: ) +def _openai_stream_event(payload: OpenAIChatCompletionChunk) -> str: + data = payload.model_dump(mode="json", exclude_none=True) + return f"data: {json.dumps(data, separators=(',', ':'))}\n\n" + + +def iter_openai_chat_completion_sse( + response: OpenAIChatCompletionResponse, + *, + include_usage: bool, +) -> Iterator[str]: + created = int(response.created) + completion_id = response.id + model = response.model + content = response.choices[0].message.content + finish_reason = response.choices[0].finish_reason + + yield _openai_stream_event( + OpenAIChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created, + model=model, + choices=[ + OpenAIChatCompletionChunkChoice( + index=0, + delta=OpenAIChatCompletionChunkChoiceDelta(role="assistant"), + finish_reason=None, + ) + ], + ) + ) + yield _openai_stream_event( + OpenAIChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created, + model=model, + choices=[ + OpenAIChatCompletionChunkChoice( + index=0, + delta=OpenAIChatCompletionChunkChoiceDelta(content=content), + finish_reason=None, + ) + ], + ) + ) + yield _openai_stream_event( + OpenAIChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created, + model=model, + choices=[ + OpenAIChatCompletionChunkChoice( + index=0, + delta=OpenAIChatCompletionChunkChoiceDelta(), + finish_reason=finish_reason, + ) + ], + ) + ) + if include_usage: + yield _openai_stream_event( + OpenAIChatCompletionChunk( + id=completion_id, + object="chat.completion.chunk", + created=created, + model=model, + choices=[], + usage=response.usage, + ) + ) + yield "data: [DONE]\n\n" + + async def list_openai_models() -> OpenAIModelListResponse: try: profiles = LLMProfileStore().list_summaries() diff --git a/tests/cross/test_remote_conversation_live_server.py b/tests/cross/test_remote_conversation_live_server.py index b61cdb131f..05c2589de0 100644 --- a/tests/cross/test_remote_conversation_live_server.py +++ b/tests/cross/test_remote_conversation_live_server.py @@ -678,6 +678,55 @@ def test_openai_chat_completions_gateway_over_real_server( "content": "Hello from patched LLM", } + from openai import OpenAI + + openai_client = OpenAI( + api_key="unused", + base_url=f"{env['host']}/v1", + timeout=10, + ) + stream = openai_client.chat.completions.create( + model="openhands_smoke", + messages=[ + {"role": "developer", "content": "Answer tersely."}, + {"role": "user", "content": "Say hello as a stream."}, + ], + stream=True, + stream_options={"include_usage": True}, + user="compat-test-user", + ) + chunks = list(stream) + streamed_text = "".join( + chunk.choices[0].delta.content or "" + for chunk in chunks + if chunk.choices + ) + usage_chunks = [chunk.usage for chunk in chunks if chunk.usage] + assert streamed_text == "Hello from patched LLM" + assert usage_chunks[-1].prompt_tokens == 7 + assert usage_chunks[-1].completion_tokens == 5 + assert usage_chunks[-1].total_tokens == 12 + + stream = openai_client.chat.completions.create( + model="openhands_smoke", + messages=[ + { + "role": "user", + "content": "Say hello as a default stream.", + }, + ], + stream=True, + ) + chunks = list(stream) + streamed_text = "".join( + chunk.choices[0].delta.content or "" + for chunk in chunks + if chunk.choices + ) + usage_chunks = [chunk.usage for chunk in chunks if chunk.usage] + assert streamed_text == "Hello from patched LLM" + assert usage_chunks == [] + def test_openai_gateway_replays_frozen_llm_fixtures( tmp_path: Path, monkeypatch: pytest.MonkeyPatch