diff --git a/src/bub/builtin/model_runner.py b/src/bub/builtin/model_runner.py index 0684f8fc..2664b964 100644 --- a/src/bub/builtin/model_runner.py +++ b/src/bub/builtin/model_runner.py @@ -10,6 +10,7 @@ from typing import Any, Literal, cast from any_llm import AnyLLM +from any_llm.providers.openai.base import BaseOpenAIProvider from any_llm.types.completion import ( ChatCompletion, ChatCompletionChunk, @@ -36,6 +37,19 @@ CompletionResult = ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk] +def _stream_usage_options(llm: AnyLLM, *, stream: bool) -> dict[str, Any] | None: + """Make streaming completions report token usage. + + OpenAI-style streaming responses omit the `usage` block unless the request + sets `stream_options.include_usage`; without it every streamed run records + zero tokens (and zero cost). Only OpenAI-compatible providers accept the + field, so gate on the provider base class — anthropic/gemini reject it. + """ + if stream and isinstance(llm, BaseOpenAIProvider): + return {"include_usage": True} + return None + + class ModelRunner: def __init__(self, settings: AgentSettings) -> None: self.settings = settings @@ -61,12 +75,14 @@ async def completion_response( completion_error: Exception | None = None for index, (candidate, llm) in enumerate(clients): try: + streaming = llm.SUPPORTS_COMPLETION_STREAMING return await llm.acompletion( model=candidate.model_id, messages=completion_messages, tools=tool_payloads, max_tokens=self.settings.max_tokens, - stream=llm.SUPPORTS_COMPLETION_STREAMING, + stream=streaming, + stream_options=_stream_usage_options(llm, stream=streaming), ) except Exception as exc: if completion_error is None: diff --git a/tests/test_builtin_model_runner.py b/tests/test_builtin_model_runner.py new file mode 100644 index 00000000..7daebe71 --- /dev/null +++ b/tests/test_builtin_model_runner.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator +from 
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from typing import Any

import pytest
from any_llm.constants import LLMProvider
from any_llm.providers.openai.base import BaseOpenAIProvider
from any_llm.types.completion import ChatCompletionChunk

from bub.builtin.model_runner import ModelRunner
from bub.builtin.settings import AgentSettings, ModelCandidate
from bub.builtin.tape import Tape
from bub.tape import AsyncTapeStoreAdapter, InMemoryTapeStore, TapeContext


class _FakeStreamingOpenAIProvider(BaseOpenAIProvider):
    """Provider stub that records its request and replays a canned stream.

    `acompletion` stores the keyword arguments it was called with on
    `completion_kwargs` and hands back a two-chunk stream: one content chunk
    with the text "done", then a final chunk with no choices. The final chunk
    carries a `usage` block only when the request's `stream_options` equals
    `{"include_usage": True}`.
    """

    SUPPORTS_COMPLETION_STREAMING = True

    def __init__(self) -> None:
        # The base-class initializer is intentionally not invoked here.
        # NOTE(review): presumably to avoid needing real client configuration — confirm.
        self.completion_kwargs: dict[str, Any] | None = None

    async def acompletion(self, **kwargs: Any) -> AsyncIterator[ChatCompletionChunk]:
        """Remember `kwargs` and return an async iterator over the canned chunks."""
        self.completion_kwargs = kwargs
        # Decided once, at call time, from the exact payload the caller sent.
        wants_usage = kwargs.get("stream_options") == {"include_usage": True}

        async def chunks() -> AsyncIterator[ChatCompletionChunk]:
            # Fields shared by every chunk of this fake response.
            envelope: dict[str, Any] = {
                "id": "chatcmpl_test",
                "object": "chat.completion.chunk",
                "created": 0,
                "model": "gpt-test",
            }
            content_choice = {
                "index": 0,
                "finish_reason": None,
                "delta": {"role": "assistant", "content": "done"},
            }
            yield ChatCompletionChunk.model_validate({**envelope, "choices": [content_choice]})

            # Trailing chunk: empty choices, usage attached only on request.
            trailer: dict[str, Any] = {**envelope, "choices": []}
            if wants_usage:
                trailer["usage"] = {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}
            yield ChatCompletionChunk.model_validate(trailer)

        return chunks()


class _FakeOpenAIModelRunner(ModelRunner):
    """`ModelRunner` whose only candidate client is the injected fake provider."""

    def __init__(self, settings: AgentSettings, llm: _FakeStreamingOpenAIProvider) -> None:
        super().__init__(settings)
        self._llm = llm

    def iter_llm_clients(self, model: str) -> Iterator[tuple[ModelCandidate, _FakeStreamingOpenAIProvider]]:
        """Yield a single OpenAI candidate for `model`, paired with the fake provider."""
        candidate = ModelCandidate(provider=LLMProvider.OPENAI, model_id=model, name=f"openai:{model}")
        yield candidate, self._llm


@pytest.mark.asyncio
async def test_streaming_openai_usage_is_requested_and_recorded_in_tape(tmp_path: Path) -> None:
    """A streamed OpenAI run requests usage and stores it on the tape.

    Verifies three things: the provider was called with `stream=True` and
    `stream_options={"include_usage": True}`; the run emitted exactly a text
    event followed by a final event; and the tape holds exactly one "run"
    event entry whose usage matches the numbers in the fake's final chunk.
    """
    tape_store = InMemoryTapeStore()
    scoped_tape = Tape(tmp_path, AsyncTapeStoreAdapter(tape_store), TapeContext()).scoped("test-tape")
    provider = _FakeStreamingOpenAIProvider()
    settings = AgentSettings.model_construct(model="openai:gpt-test", max_tokens=100, model_timeout_seconds=None)
    runner = _FakeOpenAIModelRunner(settings, provider)

    await scoped_tape.ensure_bootstrap_anchor()
    events = []
    async for event in runner.run(
        tape=scoped_tape, model="gpt-test", tools=[], system_prompt=None, prompt="hello"
    ):
        events.append(event)

    # What the runner actually sent to the provider.
    sent = provider.completion_kwargs
    assert sent is not None
    assert sent["stream"] is True
    assert sent["stream_options"] == {"include_usage": True}

    # What the runner surfaced to its caller.
    observed = [(event.kind, event.data) for event in events]
    assert observed == [
        ("text", {"delta": "done"}),
        ("final", {"ok": True, "text": "done"}),
    ]

    # What ended up persisted on the tape.
    entries = tape_store.read("test-tape") or []
    run_entries = [
        entry for entry in entries if entry.kind == "event" and entry.payload.get("name") == "run"
    ]
    assert len(run_entries) == 1
    recorded_usage = run_entries[0].payload["data"]["usage"]
    assert recorded_usage == {
        "completion_tokens": 2,
        "prompt_tokens": 3,
        "total_tokens": 5,
    }