diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py
index 291a684..08f058f 100644
--- a/tee_gateway/controllers/chat_controller.py
+++ b/tee_gateway/controllers/chat_controller.py
@@ -911,9 +911,7 @@ def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict:
             messages.append(
                 {
                     "role": "user",
-                    "content": msg.content
-                    if isinstance(msg.content, str)
-                    else str(msg.content),
+                    "content": msg.content,
                 }
             )
         elif isinstance(msg, ChatCompletionRequestAssistantMessage):
diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py
index 0e80d22..c0da980 100644
--- a/tee_gateway/llm_backend.py
+++ b/tee_gateway/llm_backend.py
@@ -325,6 +325,17 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
     return images, len(images)
 
 
+def _normalize_user_content_parts(content: list) -> list:
+    """Preserve multimodal user content while tolerating primitive text parts."""
+    normalized = []
+    for part in content:
+        if isinstance(part, dict):
+            normalized.append(part)
+        else:
+            normalized.append({"type": "text", "text": str(part)})
+    return normalized
+
+
 def convert_messages(messages: list) -> List[Any]:
     """Convert OpenAI-format message objects or dicts to LangChain message objects."""
     langchain_messages: List[BaseMessage] = []
@@ -333,13 +344,17 @@ def convert_messages(messages: list) -> List[Any]:
         # Support both OpenAPI model objects and plain dicts
         if isinstance(msg, dict):
             role = msg.get("role", "").lower()
-            content = msg.get("content", "") or ""
+            content = msg.get("content", "")
+            if content is None:
+                content = ""
             tool_calls = msg.get("tool_calls")
             tool_call_id = msg.get("tool_call_id")
             name = msg.get("name")
         else:
             role = getattr(msg, "role", "").lower()
-            content = getattr(msg, "content", "") or ""
+            content = getattr(msg, "content", "")
+            if content is None:
+                content = ""
             tool_calls = getattr(msg, "tool_calls", None)
             tool_call_id = getattr(msg, "tool_call_id", None)
             name = getattr(msg, "name", None)
@@ -348,12 +363,8 @@ def convert_messages(messages: list) -> List[Any]:
             langchain_messages.append(SystemMessage(content=content))
 
         elif role == "user":
-            # content may be a string or a list of content parts; handle both
             if isinstance(content, list):
-                content = "".join(
-                    part.get("text", "") if isinstance(part, dict) else str(part)
-                    for part in content
-                )
+                content = _normalize_user_content_parts(content)
             langchain_messages.append(HumanMessage(content=content))
 
         elif role == "assistant":
diff --git a/tee_gateway/test/test_tee_core.py b/tee_gateway/test/test_tee_core.py
index 41aac33..2411176 100644
--- a/tee_gateway/test/test_tee_core.py
+++ b/tee_gateway/test/test_tee_core.py
@@ -569,20 +569,30 @@ def test_multi_turn_order_preserved(self):
         self.assertIsInstance(result[2], AIMessage)
 
     def test_user_content_as_list_of_parts(self):
-        """Multimodal content parts should be concatenated into a single string."""
+        """Multimodal content parts should be preserved for vision-capable models."""
+        content = [
+            {"type": "text", "text": "Hello world"},
+            {
+                "type": "image_url",
+                "image_url": {"url": "data:image/png;base64,abcd"},
+            },
+        ]
         result = convert_messages(
             [
                 {
                     "role": "user",
-                    "content": [
-                        {"type": "text", "text": "Hello "},
-                        {"type": "text", "text": "world"},
-                    ],
+                    "content": content,
                 }
             ]
         )
         self.assertIsInstance(result[0], HumanMessage)
-        self.assertEqual(result[0].content, "Hello world")
+        self.assertEqual(result[0].content, content)
+
+    def test_empty_user_content_list_is_preserved(self):
+        """Empty multimodal content lists should not be coerced to empty strings."""
+        result = convert_messages([{"role": "user", "content": []}])
+        self.assertIsInstance(result[0], HumanMessage)
+        self.assertEqual(result[0].content, [])
 
     def test_full_tool_call_conversation(self):
         """End-to-end multi-turn with tool use: user → assistant (tool call) → tool result."""
diff --git a/tests/test_structured_outputs.py b/tests/test_structured_outputs.py
index 3db565f..71fbd20 100644
--- a/tests/test_structured_outputs.py
+++ b/tests/test_structured_outputs.py
@@ -9,6 +9,9 @@ from tee_gateway.models.create_chat_completion_request import (
     CreateChatCompletionRequest,
 )
 
+from tee_gateway.models.chat_completion_request_user_message import (
+    ChatCompletionRequestUserMessage,
+)
 
 
 class TestResponseFormatParsing(unittest.TestCase):
@@ -109,6 +112,26 @@ def test_hash_differs_with_and_without_response_format(self):
         h2 = json.dumps(chat_request_to_dict(req_json), sort_keys=True)
         self.assertNotEqual(h1, h2)
 
+    def test_hash_dict_preserves_multimodal_user_content(self):
+        content = [
+            {"type": "text", "text": "Describe this image"},
+            {
+                "type": "image_url",
+                "image_url": {"url": "data:image/png;base64,abcd"},
+            },
+        ]
+        req = CreateChatCompletionRequest(
+            model="gpt-4.1",
+            messages=[ChatCompletionRequestUserMessage(role="user", content=content)],
+            temperature=1.0,
+        )
+        request_dict = chat_request_to_dict(req)
+        self.assertEqual(request_dict["messages"][0]["content"], content)
+
+        dumped_once = json.dumps(request_dict, sort_keys=True)
+        dumped_twice = json.dumps(chat_request_to_dict(req), sort_keys=True)
+        self.assertEqual(dumped_once, dumped_twice)
+
 
 class TestResponseFormatModelBinding(unittest.TestCase):
     """Tests that response_format is bound to the model before invocation."""