From dda553852d4372ead325867ccdd836e917d15cfe Mon Sep 17 00:00:00 2001 From: lif <1835304752@qq.com> Date: Sat, 20 Dec 2025 23:40:22 +0800 Subject: [PATCH] fix(streaming): raise on thread error events --- src/openai/_streaming.py | 72 ++++++++++------------------------ tests/test_streaming_errors.py | 52 ++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 52 deletions(-) create mode 100644 tests/test_streaming_errors.py diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index 61a742668a..373e6fb9a0 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -19,6 +19,22 @@ _T = TypeVar("_T") +def _raise_if_stream_error(data: object, response: httpx.Response) -> None: + if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIError( + message=message, + request=response.request, + body=data["error"], + ) + + class Stream(Generic[_T]): """Provides the core interface to iterate over a synchronous stream response.""" @@ -64,36 +80,12 @@ def __stream__(self) -> Iterator[_T]: if sse.event and sse.event.startswith("thread."): data = sse.json() - if sse.event == "error" and is_mapping(data) and data.get("error"): - message = None - error = data.get("error") - if is_mapping(error): - message = error.get("message") - if not message or not isinstance(message, str): - message = "An error occurred during streaming" - - raise APIError( - message=message, - request=self.response.request, - body=data["error"], - ) + _raise_if_stream_error(data, response) yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) else: data = sse.json() - if is_mapping(data) and data.get("error"): - message = None - error = data.get("error") - if is_mapping(error): - message = error.get("message") - if not message or not isinstance(message, str): - message = "An error occurred during streaming" - - raise APIError( - message=message, - request=self.response.request, - body=data["error"], - ) + _raise_if_stream_error(data, response) yield process_data(data=data, cast_to=cast_to, response=response) @@ -167,36 +159,12 @@ async def __stream__(self) -> AsyncIterator[_T]: if sse.event and sse.event.startswith("thread."): data = sse.json() - if sse.event == "error" and is_mapping(data) and data.get("error"): - message = None - error = data.get("error") - if is_mapping(error): - message = error.get("message") - if not message or not isinstance(message, str): - message = "An error occurred during streaming" - - raise APIError( - message=message, - request=self.response.request, - body=data["error"], - ) + _raise_if_stream_error(data, response) yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) else: data = sse.json() - if is_mapping(data) and data.get("error"): - message = None - error = data.get("error") - if is_mapping(error): - message = error.get("message") - if not message or not isinstance(message, str): - message = "An error occurred during streaming" - - raise APIError( - message=message, - request=self.response.request, - body=data["error"], - ) + _raise_if_stream_error(data, response) yield process_data(data=data, cast_to=cast_to, response=response) diff --git a/tests/test_streaming_errors.py b/tests/test_streaming_errors.py new file mode 100644 index 0000000000..ed40cb653d --- /dev/null +++ b/tests/test_streaming_errors.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Iterator, AsyncIterator + +import httpx +import pytest + +from openai import OpenAI, AsyncOpenAI +from openai._exceptions import APIError +from openai._streaming import Stream, AsyncStream + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_thread_event_error_raises(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None: + def body() -> Iterator[bytes]: + yield b"event: thread.error\n" + yield b'data: {"error": {"message": "boom"}}\n' + yield b"\n" + + iterator = make_stream_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + with pytest.raises(APIError, match="boom"): + await iter_next(iterator) + + +async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: + for chunk in iter: + yield chunk + + +async def iter_next(iter: Iterator[object] | AsyncIterator[object]) -> object: + if isinstance(iter, AsyncIterator): + return await iter.__anext__() + + return next(iter) + + +def make_stream_iterator( + content: Iterator[bytes], + *, + sync: bool, + client: OpenAI, + async_client: AsyncOpenAI, +) -> Iterator[object] | AsyncIterator[object]: + request = httpx.Request("GET", "http://test") + if sync: + response = httpx.Response(200, request=request, content=content) + return iter(Stream(cast_to=object, client=client, response=response)) + + response = httpx.Response(200, request=request, content=to_aiter(content)) + return AsyncStream(cast_to=object, client=async_client, response=response).__aiter__()