Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 20 additions & 52 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@
_T = TypeVar("_T")


def _raise_if_stream_error(data: object, response: httpx.Response) -> None:
    """Raise an ``APIError`` if *data* is a mapping carrying a truthy ``error`` key.

    A no-op for non-mapping payloads or mappings without an ``error`` entry,
    so it is safe to call on every decoded SSE chunk.
    """
    if not is_mapping(data):
        return

    error = data.get("error")
    if not error:
        return

    # Prefer the server-supplied message; fall back to a generic one when it
    # is missing or not a plain string.
    message = error.get("message") if is_mapping(error) else None
    if not isinstance(message, str) or not message:
        message = "An error occurred during streaming"

    raise APIError(
        message=message,
        request=response.request,
        body=error,
    )


class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""

Expand Down Expand Up @@ -64,36 +80,12 @@ def __stream__(self) -> Iterator[_T]:
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
_raise_if_stream_error(data, response)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
_raise_if_stream_error(data, response)

yield process_data(data=data, cast_to=cast_to, response=response)

Expand Down Expand Up @@ -167,36 +159,12 @@ async def __stream__(self) -> AsyncIterator[_T]:
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
_raise_if_stream_error(data, response)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
_raise_if_stream_error(data, response)

yield process_data(data=data, cast_to=cast_to, response=response)

Expand Down
52 changes: 52 additions & 0 deletions tests/test_streaming_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import Iterator, AsyncIterator

import httpx
import pytest

from openai import OpenAI, AsyncOpenAI
from openai._exceptions import APIError
from openai._streaming import Stream, AsyncStream


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_thread_event_error_raises(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """A ``thread.*`` SSE event carrying an ``error`` payload raises ``APIError``."""

    def sse_chunks() -> Iterator[bytes]:
        # One complete SSE message: event line, data line, blank-line terminator.
        yield b"event: thread.error\n"
        yield b'data: {"error": {"message": "boom"}}\n'
        yield b"\n"

    stream_iter = make_stream_iterator(
        content=sse_chunks(),
        sync=sync,
        client=client,
        async_client=async_client,
    )

    with pytest.raises(APIError, match="boom"):
        await iter_next(stream_iter)


async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
    """Adapt a synchronous byte iterator into an asynchronous one."""
    for item in iter:
        yield item


async def iter_next(iter: Iterator[object] | AsyncIterator[object]) -> object:
    """Advance *iter* by one element, awaiting when it is asynchronous."""
    if not isinstance(iter, AsyncIterator):
        return next(iter)
    return await iter.__anext__()


def make_stream_iterator(
    content: Iterator[bytes],
    *,
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> Iterator[object] | AsyncIterator[object]:
    """Build a sync or async stream iterator serving *content* from a dummy response."""
    dummy_request = httpx.Request("GET", "http://test")

    if not sync:
        # Async path: httpx needs an async byte source for the response body.
        async_response = httpx.Response(200, request=dummy_request, content=to_aiter(content))
        return AsyncStream(cast_to=object, client=async_client, response=async_response).__aiter__()

    sync_response = httpx.Response(200, request=dummy_request, content=content)
    return iter(Stream(cast_to=object, client=client, response=sync_response))