diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py
index 45c13cc11d..1be7195026 100644
--- a/src/openai/_streaming.py
+++ b/src/openai/_streaming.py
@@ -63,6 +63,9 @@ def __stream__(self) -> Iterator[_T]:
             if sse.data.startswith("[DONE]"):
                 break
 
+            if not sse.data:
+                continue
+
             # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
             if sse.event and sse.event.startswith("thread."):
                 data = sse.json()
@@ -173,6 +176,9 @@ async def __stream__(self) -> AsyncIterator[_T]:
             if sse.data.startswith("[DONE]"):
                 break
 
+            if not sse.data:
+                continue
+
             # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
             if sse.event and sse.event.startswith("thread."):
                 data = sse.json()
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 04f8e51abd..16571bc6ba 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -216,6 +216,44 @@ def body() -> Iterator[bytes]:
     assert sse.json() == {"content": "известни"}
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_stream_skips_meta_only_retry_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"retry: 3000\n\n"
+        yield b'data: {"foo":true}\n\n'
+
+    results = await collect_stream(body(), sync=sync, client=client, async_client=async_client)
+    assert len(results) == 1
+    assert results[0] == {"foo": True}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_stream_skips_event_without_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n\n"
+        yield b'data: {"bar":1}\n\n'
+        yield b"data: [DONE]\n\n"
+
+    results = await collect_stream(body(), sync=sync, client=client, async_client=async_client)
+    assert len(results) == 1
+    assert results[0] == {"bar": 1}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_stream_skips_thread_event_without_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: thread.created\n\n"
+        yield b'event: thread.created\ndata: {"id":"t1"}\n\n'
+        yield b"data: [DONE]\n\n"
+
+    results = await collect_stream(body(), sync=sync, client=client, async_client=async_client)
+    assert len(results) == 1
+    assert results[0] == {"data": {"id": "t1"}, "event": "thread.created"}
+
+
 async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
     for chunk in iter:
         yield chunk
@@ -246,3 +284,22 @@ def make_event_iterator(
     return AsyncStream(
         cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content))
     )._iter_events()
+
+
+async def collect_stream(
+    content: Iterator[bytes],
+    *,
+    sync: bool,
+    client: OpenAI,
+    async_client: AsyncOpenAI,
+) -> list[object]:
+    results: list[object] = []
+    if sync:
+        stream = Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))
+        for item in stream:
+            results.append(item)
+    else:
+        stream = AsyncStream(cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)))
+        async for item in stream:
+            results.append(item)
+    return results