Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __stream__(self) -> Iterator[_T]:
if sse.data.startswith("[DONE]"):
break

if not sse.data:
continue

# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()
Expand Down Expand Up @@ -173,6 +176,9 @@ async def __stream__(self) -> AsyncIterator[_T]:
if sse.data.startswith("[DONE]"):
break

if not sse.data:
continue

# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()
Expand Down
57 changes: 57 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,44 @@ def body() -> Iterator[bytes]:
assert sse.json() == {"content": "известни"}


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_stream_skips_meta_only_retry_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """An SSE message carrying only a `retry:` field has no data and must not be yielded."""

    def body() -> Iterator[bytes]:
        # First frame sets only the reconnection interval; second carries real JSON.
        yield from (b"retry: 3000\n\n", b'data: {"foo":true}\n\n')

    events = await collect_stream(body(), sync=sync, client=client, async_client=async_client)
    assert events == [{"foo": True}]


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_stream_skips_event_without_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """A named event with an empty data payload is dropped instead of being parsed."""

    def body() -> Iterator[bytes]:
        # A data-less "ping" event, one real JSON event, then the terminator.
        yield from (b"event: ping\n\n", b'data: {"bar":1}\n\n', b"data: [DONE]\n\n")

    events = await collect_stream(body(), sync=sync, client=client, async_client=async_client)
    assert events == [{"bar": 1}]


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_stream_skips_thread_event_without_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """A `thread.*` event without data must be skipped before the special-case JSON parse."""

    def body() -> Iterator[bytes]:
        # Same event name twice: first with no data (skipped), then with a payload.
        yield from (
            b"event: thread.created\n\n",
            b'event: thread.created\ndata: {"id":"t1"}\n\n',
            b"data: [DONE]\n\n",
        )

    events = await collect_stream(body(), sync=sync, client=client, async_client=async_client)
    assert events == [{"data": {"id": "t1"}, "event": "thread.created"}]


async def to_aiter(chunks: Iterator[bytes]) -> AsyncIterator[bytes]:
    """Adapt a synchronous byte iterator into an async iterator.

    Used to feed synchronous test fixtures into `AsyncStream`, which expects an
    async byte source. The parameter was renamed from ``iter`` so it no longer
    shadows the builtin; all visible call sites pass it positionally.
    """
    for chunk in chunks:
        yield chunk
Expand Down Expand Up @@ -246,3 +284,22 @@ def make_event_iterator(
return AsyncStream(
cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content))
)._iter_events()


async def collect_stream(
    content: Iterator[bytes],
    *,
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> list[object]:
    """Consume an SSE byte stream through either the sync or async client path.

    Returns every parsed event in order, so tests can assert on the full list.
    """
    if sync:
        response = httpx.Response(200, content=content)
        return list(Stream(cast_to=object, client=client, response=response))

    response = httpx.Response(200, content=to_aiter(content))
    return [event async for event in AsyncStream(cast_to=object, client=async_client, response=response)]