diff --git a/tests/test_cli.py b/tests/test_cli.py index f312c86..6bde552 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -812,3 +812,385 @@ async def test_get_client_auto_discover(self, mock_env): storage_api_url="https://connection.keboola.com", ) assert client == mock_client + + +class TestInfoCommandEdgeCases: + """Tests for the info command with edge case MCP formats.""" + + def test_info_with_single_dict_mcp(self, runner, mock_env): + """Test info command when connectedMcp is a single dict instead of list.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.info = AsyncMock( + return_value=InfoResponse( + timestamp="2025-01-08T12:00:00Z", + uptime=100.0, + appName="kai-backend", + appVersion="1.0.0", + serverVersion="2.0.0", + connectedMcp={"name": "keboola-mcp", "status": "connected"}, + ) + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke(main, ["info"]) + + assert result.exit_code == 0 + assert "keboola-mcp" in result.output + assert "connected" in result.output + + +class TestChatInteractiveMode: + """Tests for interactive chat mode.""" + + def test_chat_interactive_exit(self, runner, mock_env): + """Test interactive mode with 'exit' command.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke(main, ["chat"], input="exit\n") + + assert result.exit_code == 0 + assert "Interactive chat mode" in result.output + assert "Chat ended" in result.output + + def test_chat_interactive_quit(self, runner, mock_env): + """Test interactive mode with 'quit' command.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke(main, ["chat"], input="quit\n") + + assert result.exit_code == 0 + assert "Chat ended" in result.output + + def test_chat_interactive_empty_input_skipped(self, runner, mock_env): + """Test that empty input is skipped in interactive mode.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + send_call_count = 0 + + async def mock_send_message(chat_id, message): + nonlocal send_call_count + send_call_count += 1 + yield TextEvent(type="text", text="Response") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + # Empty line, then a message, then exit + result = runner.invoke(main, ["chat"], input="\nHello\nexit\n") + + assert result.exit_code == 0 + assert send_call_count == 1 # Only the "Hello" message was sent + + def test_chat_interactive_with_message_and_response(self, runner, mock_env): + """Test interactive mode sends message and shows response.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + async def mock_send_message(chat_id, message): + yield TextEvent(type="text", text="I can help!") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke(main, ["chat"], input="Hello\nexit\n") + + assert result.exit_code == 0 + assert "I can help!" in result.output + + +class TestChatV6ApprovalFlow: + """Tests for v6 tool approval flow in the CLI.""" + + def test_chat_v6_approval_auto_approve(self, runner, mock_env): + """Test auto-approve with v6 approval flow (approval_id).""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + from kai_client.models import ToolApprovalRequestEvent + + async def mock_send_message(chat_id, message): + # Tool starts and waits for approval + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-v6", + toolName="update_descriptions", + state="started", + ) + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-v6", + toolName="update_descriptions", + state="input-available", + input={"descriptions": "new"}, + ) + # v6 approval request event + yield ToolApprovalRequestEvent( + type="tool-approval-request", + approvalId="appr-v6-001", + toolCallId="tool-v6", + ) + + async def mock_approve_tool(chat_id, approval_id, **kwargs): + yield TextEvent(type="text", text="Descriptions updated!") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.approve_tool = mock_approve_tool + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke( + main, ["chat", "--auto-approve", "-m", "Update descriptions"] + ) + + assert result.exit_code == 0 + assert "[Auto-approving...]" in result.output + assert "Descriptions updated!" in result.output + + def test_chat_v6_approval_user_approve(self, runner, mock_env): + """Test manual user approval with v6 flow (click.confirm).""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + from kai_client.models import ToolApprovalRequestEvent + + async def mock_send_message(chat_id, message): + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-v6-2", + toolName="create_config", + state="started", + ) + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-v6-2", + toolName="create_config", + state="input-available", + input={"name": "test"}, + ) + yield ToolApprovalRequestEvent( + type="tool-approval-request", + approvalId="appr-v6-002", + toolCallId="tool-v6-2", + ) + + async def mock_approve_tool(chat_id, approval_id, **kwargs): + yield TextEvent(type="text", text="Config created!") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.approve_tool = mock_approve_tool + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + # User types 'y' to confirm + result = runner.invoke( + main, ["chat", "-m", "Create config"], input="y\n" + ) + + assert result.exit_code == 0 + assert "Config created!" in result.output + + def test_chat_v6_approval_user_reject(self, runner, mock_env): + """Test manual user rejection with v6 flow.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + from kai_client.models import ToolApprovalRequestEvent + + async def mock_send_message(chat_id, message): + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-v6-3", + toolName="delete_bucket", + state="started", + ) + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-v6-3", + toolName="delete_bucket", + state="input-available", + input={"bucket_id": "in.c-test"}, + ) + yield ToolApprovalRequestEvent( + type="tool-approval-request", + approvalId="appr-v6-003", + toolCallId="tool-v6-3", + ) + + async def mock_reject_tool(chat_id, approval_id, **kwargs): + yield TextEvent(type="text", text="OK, not deleting.") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.reject_tool = mock_reject_tool + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + # User types 'n' to reject + result = runner.invoke( + main, ["chat", "-m", "Delete bucket"], input="n\n" + ) + + assert result.exit_code == 0 + assert "OK, not deleting." in result.output + + +class TestDisplayToolResultEvents: + """Tests for display_tool_result_events function.""" + + def test_display_tool_output_error(self, runner, mock_env): + """Test that tool-output-error events are displayed.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + from kai_client.models import ToolOutputErrorEvent + + async def mock_send_message(chat_id, message): + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-err", + toolName="run_job", + state="input-available", + input={"job_id": "123"}, + ) + + async def mock_confirm_tool(chat_id, tool_call_id, tool_name): + yield ToolOutputErrorEvent( + type="tool-output-error", + toolCallId="tool-err", + errorText="Job failed: timeout", + ) + yield TextEvent(type="text", text="The job failed.") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.confirm_tool = mock_confirm_tool + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke( + main, ["chat", "--auto-approve", "-m", "Run job"] + ) + + assert result.exit_code == 0 + assert "Tool Error: Job failed: timeout" in result.output + + def test_display_tool_result_json_output(self, runner, mock_env): + """Test display_tool_result_events with json output mode.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.new_chat_id = MagicMock(return_value="test-chat-id") + + async def mock_send_message(chat_id, message): + yield ToolCallEvent( + type="tool-call", + toolCallId="tool-json", + toolName="create_config", + state="input-available", + input={"name": "test"}, + ) + + async def mock_confirm_tool(chat_id, tool_call_id, tool_name): + yield TextEvent(type="text", text="Done") + yield FinishEvent(type="finish", finishReason="stop") + + mock_client.send_message = mock_send_message + mock_client.confirm_tool = mock_confirm_tool + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke( + main, ["chat", "--auto-approve", "--json-output", "-m", "Create"] + ) + + assert result.exit_code == 0 + # All output should be JSON lines + lines = [ln for ln in result.output.strip().split("\n") if ln] + for line in lines: + parsed = json.loads(line) + assert "type" in parsed + + +class TestGetChatDisplayEdgeCases: + """Tests for get-chat display with various part types.""" + + def test_get_chat_no_title(self, runner, mock_env): + """Test get-chat when chat has no title.""" + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.get_chat = AsyncMock( + return_value=ChatDetail( + id="chat-123", + title=None, + messages=[], + ) + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke(main, ["get-chat", "chat-123"]) + + assert result.exit_code == 0 + assert "(no title)" in result.output + + def test_get_chat_with_object_parts(self, runner, mock_env): + """Test get-chat with parts that are objects (have .text / .type attrs).""" + + class FakeTextPart: + text = "Here's the result:" + + class FakeToolPart: + type = "tool-call" + + # Build a ChatDetail with object-style parts (not dicts) + chat = ChatDetail(id="chat-123", title="Test", messages=[]) + msg = Message(id="msg-1", role="assistant", parts=[]) + # Override parts with objects after construction + msg.parts = [FakeTextPart(), FakeToolPart()] # type: ignore + chat.messages = [msg] + + with patch("kai_client.cli.get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.get_chat = AsyncMock(return_value=chat) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_get_client.return_value = mock_client + + result = runner.invoke(main, ["get-chat", "chat-123"]) + + assert result.exit_code == 0 + assert "Here's the result:" in result.output + assert "[tool-call]" in result.output diff --git a/tests/test_client.py b/tests/test_client.py index 635b253..d0a11e1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1532,3 +1532,504 @@ async def test_write_tool_requires_approval( assert pending_tools[0].input is not None +class TestFromStorageApiConnectionError: + """Tests for from_storage_api when connection fails.""" + + @pytest.mark.asyncio + async def test_from_storage_api_connection_error(self, httpx_mock: HTTPXMock): + """Test error when Storage API is unreachable.""" + import httpx as httpx_lib + + from kai_client import KaiConnectionError + + httpx_mock.add_exception( + httpx_lib.RequestError("Connection refused"), + url="https://connection.keboola.com/v2/storage", + ) + + with pytest.raises(KaiConnectionError) as exc_info: + await KaiClient.from_storage_api( + storage_api_token="test-token", + storage_api_url="https://connection.keboola.com", + ) + + assert "Failed to connect to Storage API" in str(exc_info.value) + + +class TestLazyClientInit: + """Tests for lazy client initialization (without context manager).""" + + @pytest.mark.asyncio + async def test_get_client_creates_client_lazily(self, httpx_mock: HTTPXMock): + """Test that _get_client creates client on first use.""" + httpx_mock.add_response( + url="http://localhost:3000/ping", + json={"timestamp": "2025-12-24T16:24:10.641Z"}, + ) + + client = KaiClient( + storage_api_token="token", + storage_api_url="https://connection.keboola.com", + ) + + # No context manager — lazy init + assert client._client is None + await client.ping() + assert client._client is not None + await client.close() + + +class TestNonJsonHttpError: + """Tests for HTTP errors with non-JSON response bodies.""" + + @pytest.mark.asyncio + async def test_non_json_error_response(self, client: KaiClient, httpx_mock: HTTPXMock): + """Test _request when server returns non-JSON error body.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + status_code=502, + content=b"Bad Gateway", + ) + + async with client: + with pytest.raises(KaiError) as exc_info: + await client.get_chat("chat-123") + + assert "502" in str(exc_info.value) + assert exc_info.value.code == "http:502" + + +class TestStreamRequestErrorEdgeCases: + """Tests for _stream_request error handling edge cases.""" + + @pytest.mark.asyncio + async def test_stream_error_empty_body_uses_reason_phrase( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test streaming error with empty body falls back to reason phrase.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + status_code=503, + content=b"", + ) + + async with client: + with pytest.raises(KaiError) as exc_info: + async for _ in client.send_message("chat-123", "Test"): + pass + + assert "503" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_stream_error_non_json_body( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test streaming error with non-JSON body that starts with '{'.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + status_code=500, + content=b"{broken json", + ) + + async with client: + with pytest.raises(KaiError) as exc_info: + async for _ in client.send_message("chat-123", "Test"): + pass + + assert "500" in str(exc_info.value) + + +class TestV6ToolApprovalFlow: + """Tests for the Vercel AI SDK v6 tool approval flow.""" + + @pytest.mark.asyncio + async def test_send_tool_approval_response_approve( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test approving a tool via the v6 flow.""" + # Mock get_chat to return assistant message with pending approval + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + json={ + "id": "chat-123", + "messages": [ + { + "id": "msg-user", + "role": "user", + "parts": [{"type": "text", "text": "Update descriptions"}], + }, + { + "id": "msg-assistant", + "role": "assistant", + "parts": [ + {"type": "text", "text": "I'll update the descriptions."}, + { + "type": "tool-update_descriptions", + "toolCallId": "call-001", + "state": "approval-required", + "approval": { + "id": "approval-abc", + "approved": None, + }, + "input": {"descriptions": "new desc"}, + }, + ], + }, + ], + }, + ) + + # Mock the streaming response after approval + sse_response = ( + 'data: {"type":"text","text":"Descriptions updated."}\n' + 'data: {"type":"finish","finishReason":"stop"}\n' + ) + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + content=sse_response.encode(), + headers={"content-type": "text/event-stream"}, + ) + + async with client: + events = [] + async for event in client.send_tool_approval_response( + chat_id="chat-123", + approval_id="approval-abc", + approved=True, + ): + events.append(event) + + assert len(events) == 2 + assert events[0].type == "text" + assert events[0].text == "Descriptions updated." + + # Verify the POST payload has the updated assistant message + requests = httpx_mock.get_requests() + post_request = [r for r in requests if r.method == "POST"][0] + body = json.loads(post_request.content) + assert body["id"] == "chat-123" + assert body["message"]["id"] == "msg-assistant" + assert body["message"]["role"] == "assistant" + + # Find the updated tool part + tool_parts = [p for p in body["message"]["parts"] if "approval" in p] + assert len(tool_parts) == 1 + assert tool_parts[0]["state"] == "approval-responded" + assert tool_parts[0]["approval"]["approved"] is True + + @pytest.mark.asyncio + async def test_send_tool_approval_response_reject_with_reason( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test rejecting a tool with a reason via the v6 flow.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + json={ + "id": "chat-123", + "messages": [ + { + "id": "msg-assistant", + "role": "assistant", + "parts": [ + { + "type": "tool-delete_bucket", + "toolCallId": "call-002", + "state": "approval-required", + "approval": {"id": "approval-def"}, + "input": {"bucket_id": "in.c-test"}, + }, + ], + }, + ], + }, + ) + + sse_response = ( + 'data: {"type":"text","text":"Understood, I won\'t delete it."}\n' + 'data: {"type":"finish","finishReason":"stop"}\n' + ) + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + content=sse_response.encode(), + headers={"content-type": "text/event-stream"}, + ) + + async with client: + events = [] + async for event in client.send_tool_approval_response( + chat_id="chat-123", + approval_id="approval-def", + approved=False, + reason="Too dangerous", + ): + events.append(event) + + # Verify rejection payload + requests = httpx_mock.get_requests() + post_request = [r for r in requests if r.method == "POST"][0] + body = json.loads(post_request.content) + tool_part = body["message"]["parts"][0] + assert tool_part["approval"]["approved"] is False + assert tool_part["approval"]["reason"] == "Too dangerous" + + @pytest.mark.asyncio + async def test_send_tool_approval_response_not_found( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test error when approval ID is not found in chat messages.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + json={ + "id": "chat-123", + "messages": [ + { + "id": "msg-user", + "role": "user", + "parts": [{"type": "text", "text": "Hello"}], + }, + ], + }, + ) + + async with client: + with pytest.raises(KaiError) as exc_info: + async for _ in client.send_tool_approval_response( + chat_id="chat-123", + approval_id="nonexistent-approval", + approved=True, + ): + pass + + assert exc_info.value.code == "approval:not_found" + + @pytest.mark.asyncio + async def test_send_tool_approval_response_with_branch_id( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test v6 approval with branch_id parameter.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + json={ + "id": "chat-123", + "messages": [ + { + "id": "msg-assistant", + "role": "assistant", + "parts": [ + { + "type": "tool-create_config", + "toolCallId": "call-003", + "state": "approval-required", + "approval": {"id": "approval-xyz"}, + "input": {"name": "test"}, + }, + ], + }, + ], + }, + ) + + sse_response = 'data: {"type":"finish","finishReason":"stop"}\n' + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + content=sse_response.encode(), + headers={"content-type": "text/event-stream"}, + ) + + async with client: + async for _ in client.send_tool_approval_response( + chat_id="chat-123", + approval_id="approval-xyz", + approved=True, + branch_id=42, + ): + pass + + requests = httpx_mock.get_requests() + post_request = [r for r in requests if r.method == "POST"][0] + body = json.loads(post_request.content) + assert body["branchId"] == 42 + + +class TestApproveRejectToolConvenience: + """Tests for approve_tool and reject_tool convenience methods.""" + + @pytest.mark.asyncio + async def test_approve_tool(self, client: KaiClient, httpx_mock: HTTPXMock): + """Test approve_tool convenience method.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + json={ + "id": "chat-123", + "messages": [ + { + "id": "msg-a", + "role": "assistant", + "parts": [ + { + "type": "tool-run_job", + "toolCallId": "call-010", + "state": "approval-required", + "approval": {"id": "appr-001"}, + "input": {"job_id": "123"}, + }, + ], + }, + ], + }, + ) + + sse_response = ( + 'data: {"type":"text","text":"Job started."}\n' + 'data: {"type":"finish","finishReason":"stop"}\n' + ) + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + content=sse_response.encode(), + headers={"content-type": "text/event-stream"}, + ) + + async with client: + events = [] + async for event in client.approve_tool( + chat_id="chat-123", + approval_id="appr-001", + ): + events.append(event) + + assert len(events) == 2 + assert events[0].text == "Job started." + + # Verify approved=True was sent + requests = httpx_mock.get_requests() + post_request = [r for r in requests if r.method == "POST"][0] + body = json.loads(post_request.content) + tool_part = body["message"]["parts"][0] + assert tool_part["approval"]["approved"] is True + + @pytest.mark.asyncio + async def test_reject_tool(self, client: KaiClient, httpx_mock: HTTPXMock): + """Test reject_tool convenience method.""" + httpx_mock.add_response( + url="http://localhost:3000/api/chat/chat-123", + json={ + "id": "chat-123", + "messages": [ + { + "id": "msg-a", + "role": "assistant", + "parts": [ + { + "type": "tool-deploy_data_app", + "toolCallId": "call-020", + "state": "approval-required", + "approval": {"id": "appr-002"}, + "input": {"app_id": "456"}, + }, + ], + }, + ], + }, + ) + + sse_response = ( + 'data: {"type":"text","text":"OK, cancelled."}\n' + 'data: {"type":"finish","finishReason":"stop"}\n' + ) + httpx_mock.add_response( + url="http://localhost:3000/api/chat", + method="POST", + content=sse_response.encode(), + headers={"content-type": "text/event-stream"}, + ) + + async with client: + events = [] + async for event in client.reject_tool( + chat_id="chat-123", + approval_id="appr-002", + reason="Not needed", + ): + events.append(event) + + assert events[0].text == "OK, cancelled." + + requests = httpx_mock.get_requests() + post_request = [r for r in requests if r.method == "POST"][0] + body = json.loads(post_request.content) + tool_part = body["message"]["parts"][0] + assert tool_part["approval"]["approved"] is False + assert tool_part["approval"]["reason"] == "Not needed" + + +class TestGetHistoryEndingBefore: + """Tests for get_history with ending_before parameter.""" + + @pytest.mark.asyncio + async def test_get_history_with_ending_before( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test backward pagination with ending_before.""" + httpx_mock.add_response( + url="http://localhost:3000/api/history?limit=10&ending_before=chat-10", + json={"chats": [{"id": "chat-9"}], "hasMore": True}, + ) + + async with client: + history = await client.get_history(limit=10, ending_before="chat-10") + + assert len(history.chats) == 1 + request = httpx_mock.get_request() + assert "ending_before=chat-10" in str(request.url) + + +class TestParseVoteIsUpvotedFormat: + """Tests for _parse_vote with isUpvoted response format.""" + + def test_parse_vote_is_upvoted_true(self): + """Test parsing vote in isUpvoted format (upvoted).""" + data = {"chatId": "chat-1", "messageId": "msg-1", "isUpvoted": True} + vote = KaiClient._parse_vote(data) + assert vote.type == "up" + assert vote.chat_id == "chat-1" + assert vote.message_id == "msg-1" + + def test_parse_vote_is_upvoted_false(self): + """Test parsing vote in isUpvoted format (downvoted).""" + data = {"chatId": "chat-1", "messageId": "msg-1", "isUpvoted": False} + vote = KaiClient._parse_vote(data) + assert vote.type == "down" + + def test_parse_vote_is_upvoted_missing(self): + """Test parsing vote in isUpvoted format when field is missing.""" + data = {"chatId": "chat-1", "messageId": "msg-1"} + vote = KaiClient._parse_vote(data) + assert vote.type == "down" # Default when isUpvoted missing + + @pytest.mark.asyncio + async def test_get_votes_is_upvoted_format( + self, client: KaiClient, httpx_mock: HTTPXMock + ): + """Test get_votes when API returns isUpvoted format.""" + httpx_mock.add_response( + url="http://localhost:3000/api/vote?chatId=chat-123", + json=[ + {"chatId": "chat-123", "messageId": "msg-1", "isUpvoted": True}, + {"chatId": "chat-123", "messageId": "msg-2", "isUpvoted": False}, + ], + ) + + async with client: + votes = await client.get_votes("chat-123") + + assert len(votes) == 2 + assert votes[0].type == "up" + assert votes[1].type == "down" + + diff --git a/tests/test_sse.py b/tests/test_sse.py index 514e164..2c1eb8e 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -1,18 +1,24 @@ """Tests for SSE stream parser.""" +from unittest.mock import AsyncMock +import httpx +import pytest + +from kai_client.exceptions import KaiStreamError from kai_client.models import ( ErrorEvent, FinishEvent, StepStartEvent, TextEvent, + ToolApprovalRequestEvent, ToolCallEvent, ToolOutputErrorEvent, UnknownEvent, UsageEvent, UsageInfo, ) -from kai_client.sse import SSEStreamParser, parse_sse_event +from kai_client.sse import SSEStreamParser, parse_sse_event, parse_sse_stream class TestParseSSEEvent: @@ -683,3 +689,309 @@ def test_complex_conversation(self): assert parser.finish_reason == "stop" +class TestToolApprovalEvents: + """Tests for tool approval event parsing.""" + + def test_parse_tool_approval_request_event(self): + """Test parsing tool-approval-request event.""" + data = { + "type": "tool-approval-request", + "approvalId": "approval-abc-123", + "toolCallId": "call-456", + } + event = parse_sse_event(data) + assert isinstance(event, ToolApprovalRequestEvent) + assert event.type == "tool-approval-request" + assert event.approval_id == "approval-abc-123" + assert event.tool_call_id == "call-456" + + def test_parse_tool_approval_request_minimal(self): + """Test parsing tool-approval-request with minimal data.""" + data = {"type": "tool-approval-request"} + event = parse_sse_event(data) + assert isinstance(event, ToolApprovalRequestEvent) + assert event.approval_id == "" + assert event.tool_call_id == "" + + def test_parse_tool_call_with_approval(self): + """Test tool-input-available with approval metadata.""" + data = { + "type": "tool-input-available", + "toolCallId": "call-789", + "toolName": "update_descriptions", + "input": {"descriptions": "new desc"}, + "approval": { + "id": "appr-xyz", + "approved": None, + }, + } + event = parse_sse_event(data) + assert isinstance(event, ToolCallEvent) + assert event.approval is not None + assert event.approval.id == "appr-xyz" + assert event.approval.approved is None + + def test_parse_tool_call_with_approval_approved(self): + """Test tool call with approval that has been approved.""" + data = { + "type": "tool-call", + "toolCallId": "call-100", + "toolName": "create_config", + "state": "output-available", + "approval": { + "id": "appr-100", + "approved": True, + "reason": "User approved", + }, + } + event = parse_sse_event(data) + assert isinstance(event, ToolCallEvent) + assert event.approval is not None + assert event.approval.approved is True + assert event.approval.reason == "User approved" + + def test_parse_tool_call_with_empty_approval(self): + """Test tool call with empty approval dict (no id).""" + data = { + "type": "tool-call", + "toolCallId": "call-200", + "toolName": "get_tables", + "state": "started", + "approval": {}, # Empty — should result in None + } + event = parse_sse_event(data) + assert isinstance(event, ToolCallEvent) + assert event.approval is None + + +class TestParseSSEStream: + """Tests for the parse_sse_stream async generator.""" + + @pytest.mark.asyncio + async def test_parse_basic_stream(self): + """Test parsing a basic SSE stream.""" + lines = [ + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + assert events[0].type == "text" + assert events[0].text == "Hello" + assert events[1].type == "finish" + + @pytest.mark.asyncio + async def test_parse_stream_skips_empty_lines(self): + """Test that empty lines are skipped.""" + lines = [ + "", + 'data: {"type":"text","text":"Hello"}', + "", + "", + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_skips_comments(self): + """Test that SSE comments (lines starting with :) are skipped.""" + lines = [ + ": this is a comment", + 'data: {"type":"text","text":"Hello"}', + ": another comment", + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_skips_empty_data(self): + """Test that empty data lines are skipped.""" + lines = [ + "data: ", + "data: ", + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_handles_done_marker(self): + """Test that [DONE] termination marker is skipped.""" + lines = [ + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + "data: [DONE]", + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_ignores_event_field(self): + """Test that 'event:' lines are ignored without error.""" + lines = [ + "event: message", + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_ignores_id_field(self): + """Test that 'id:' lines are ignored without error.""" + lines = [ + "id: 12345", + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_ignores_retry_field(self): + """Test that 'retry:' lines are ignored without error.""" + lines = [ + "retry: 3000", + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_parse_stream_json_decode_error(self): + """Test that invalid JSON raises KaiStreamError.""" + lines = [ + "data: {invalid json}", + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = self._make_async_iter(lines) + + with pytest.raises(KaiStreamError) as exc_info: + async for _ in parse_sse_stream(response): + pass + + assert "Failed to parse SSE event" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_parse_stream_handles_stream_closed(self): + """Test graceful handling of StreamClosed.""" + + async def _iter_lines(): + yield 'data: {"type":"text","text":"Hello"}' + raise httpx.StreamClosed() + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = _iter_lines + + events = [event async for event in parse_sse_stream(response)] + assert len(events) == 1 + assert events[0].text == "Hello" + + @pytest.mark.asyncio + async def test_parse_stream_handles_remote_protocol_error(self): + """Test KaiStreamError on RemoteProtocolError.""" + + async def _iter_lines(): + yield 'data: {"type":"text","text":"Hello"}' + raise httpx.RemoteProtocolError("Connection reset") + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = _iter_lines + + with pytest.raises(KaiStreamError) as exc_info: + async for _ in parse_sse_stream(response): + pass + + assert "Connection error during streaming" in str(exc_info.value) + + @staticmethod + def _make_async_iter(lines): + """Create an async iterator function from a list of lines.""" + + async def _aiter(): + for line in lines: + yield line + + return _aiter + + +class TestSSEStreamParserConsumeStream: + """Tests for SSEStreamParser.consume_stream method.""" + + @pytest.mark.asyncio + async def test_consume_stream_yields_events(self): + """Test consume_stream yields and processes events.""" + lines = [ + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"text","text":" world"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = TestParseSSEStream._make_async_iter(lines) + + parser = SSEStreamParser() + events = [event async for event in parser.consume_stream(response)] + + assert len(events) == 3 + assert parser.text == "Hello world" + assert parser.finished is True + + @pytest.mark.asyncio + async def test_consume_stream_no_yield(self): + """Test consume_stream with yield_events=False.""" + lines = [ + 'data: {"type":"text","text":"Hello"}', + 'data: {"type":"finish","finishReason":"stop"}', + ] + + response = AsyncMock(spec=httpx.Response) + response.aiter_lines = TestParseSSEStream._make_async_iter(lines) + + parser = SSEStreamParser() + events = [ + event async for event in parser.consume_stream(response, yield_events=False) + ] + + assert len(events) == 0 # No events yielded + assert parser.text == "Hello" # But state was still updated + assert parser.finished is True + +