Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,9 +911,7 @@ def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict:
messages.append(
{
"role": "user",
"content": msg.content
if isinstance(msg.content, str)
else str(msg.content),
"content": msg.content,
}
Comment on lines 911 to 915
)
elif isinstance(msg, ChatCompletionRequestAssistantMessage):
Expand Down
25 changes: 18 additions & 7 deletions tee_gateway/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,17 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
return images, len(images)


def _normalize_user_content_parts(content: list) -> list:
    """Preserve multimodal user content while tolerating primitive text parts.

    Args:
        content: The content parts of a user message. Dict parts (for
            example ``{"type": "text", ...}`` or ``{"type": "image_url", ...}``)
            are passed through as the very same objects; any other value is
            treated as a primitive text fragment.

    Returns:
        A new list in which every dict part appears unchanged and every other
        non-``None`` part is wrapped as ``{"type": "text", "text": str(part)}``.
        ``None`` parts are dropped. The input list is not mutated.
    """
    normalized = []
    for part in content:
        if part is None:
            # str(None) would inject the literal word "None" into the prompt,
            # so a null part is skipped instead of being stringified.
            continue
        if isinstance(part, dict):
            # Structured parts (text, image_url, ...) must reach the model intact.
            normalized.append(part)
        else:
            normalized.append({"type": "text", "text": str(part)})
    return normalized


def convert_messages(messages: list) -> List[Any]:
"""Convert OpenAI-format message objects or dicts to LangChain message objects."""
langchain_messages: List[BaseMessage] = []
Expand All @@ -333,13 +344,17 @@ def convert_messages(messages: list) -> List[Any]:
# Support both OpenAPI model objects and plain dicts
if isinstance(msg, dict):
role = msg.get("role", "").lower()
content = msg.get("content", "") or ""
content = msg.get("content", "")
if content is None:
content = ""
tool_calls = msg.get("tool_calls")
tool_call_id = msg.get("tool_call_id")
name = msg.get("name")
else:
role = getattr(msg, "role", "").lower()
content = getattr(msg, "content", "") or ""
content = getattr(msg, "content", "")
if content is None:
content = ""
tool_calls = getattr(msg, "tool_calls", None)
tool_call_id = getattr(msg, "tool_call_id", None)
name = getattr(msg, "name", None)
Expand All @@ -348,12 +363,8 @@ def convert_messages(messages: list) -> List[Any]:
langchain_messages.append(SystemMessage(content=content))

elif role == "user":
# content may be a string or a list of content parts; handle both
if isinstance(content, list):
content = "".join(
part.get("text", "") if isinstance(part, dict) else str(part)
for part in content
)
content = _normalize_user_content_parts(content)
langchain_messages.append(HumanMessage(content=content))
Comment on lines 365 to 368

elif role == "assistant":
Expand Down
22 changes: 16 additions & 6 deletions tee_gateway/test/test_tee_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,20 +569,30 @@ def test_multi_turn_order_preserved(self):
self.assertIsInstance(result[2], AIMessage)

def test_user_content_as_list_of_parts(self):
"""Multimodal content parts should be concatenated into a single string."""
"""Multimodal content parts should be preserved for vision-capable models."""
content = [
{"type": "text", "text": "Hello world"},
{
Comment on lines 571 to +575
"type": "image_url",
"image_url": {"url": "data:image/png;base64,abcd"},
},
]
result = convert_messages(
[
{
"role": "user",
"content": [
{"type": "text", "text": "Hello "},
{"type": "text", "text": "world"},
],
"content": content,
}
]
)
self.assertIsInstance(result[0], HumanMessage)
self.assertEqual(result[0].content, "Hello world")
self.assertEqual(result[0].content, content)

def test_empty_user_content_list_is_preserved(self):
    """A user message whose content is ``[]`` must keep that empty list after conversion."""
    empty_user_message = {"role": "user", "content": []}
    converted = convert_messages([empty_user_message])
    first_message = converted[0]
    self.assertIsInstance(first_message, HumanMessage)
    self.assertEqual(first_message.content, [])

def test_full_tool_call_conversation(self):
"""End-to-end multi-turn with tool use: user → assistant (tool call) → tool result."""
Expand Down
23 changes: 23 additions & 0 deletions tests/test_structured_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from tee_gateway.models.create_chat_completion_request import (
CreateChatCompletionRequest,
)
from tee_gateway.models.chat_completion_request_user_message import (
ChatCompletionRequestUserMessage,
)


class TestResponseFormatParsing(unittest.TestCase):
Expand Down Expand Up @@ -109,6 +112,26 @@ def test_hash_differs_with_and_without_response_format(self):
h2 = json.dumps(chat_request_to_dict(req_json), sort_keys=True)
self.assertNotEqual(h1, h2)

def test_hash_dict_preserves_multimodal_user_content(self):
    """Multimodal user content must pass through the request dict unchanged.

    Also checks that serialising the request dict twice yields the same JSON
    text, so the dumped form is stable for a given request.
    """
    text_part = {"type": "text", "text": "Describe this image"}
    image_part = {
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64,abcd"},
    }
    parts = [text_part, image_part]
    user_message = ChatCompletionRequestUserMessage(role="user", content=parts)
    request = CreateChatCompletionRequest(
        model="gpt-4.1",
        messages=[user_message],
        temperature=1.0,
    )

    as_dict = chat_request_to_dict(request)
    self.assertEqual(as_dict["messages"][0]["content"], parts)

    # Two independent conversions must serialise to identical JSON.
    first_dump = json.dumps(as_dict, sort_keys=True)
    second_dump = json.dumps(chat_request_to_dict(request), sort_keys=True)
    self.assertEqual(first_dump, second_dump)


class TestResponseFormatModelBinding(unittest.TestCase):
"""Tests that response_format is bound to the model before invocation."""
Expand Down
Loading