
Commit e39fd8e

[Serve.llm] Mitigate the serve.llm streaming overhead by properly batching stream chunks (#52766)
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: kouroshhakha <[email protected]>
1 parent 0617445 commit e39fd8e

File tree

8 files changed (+995, -568 lines)

python/ray/llm/_internal/serve/deployments/llm/llm_server.py

Lines changed: 293 additions & 254 deletions
Large diffs are not rendered by default.
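The llm_server.py diff is collapsed in this view, so the exact server-side wiring is not shown. As a rough, hypothetical sketch of the direction this commit describes (batching moved out of the engine and up into the server layer), the server's streaming path could wrap its OpenAI chunk generator in the new batcher, honoring the stream_batching_interval_ms experimental config and the MODEL_RESPONSE_BATCH_TIMEOUT_MS default that the removed vllm_engine.py code used to read. The function name below is made up for illustration; only the import paths and the config key come from the diff.

# Hypothetical sketch, not the actual llm_server.py change (that diff is not rendered here).
from ray.llm._internal.serve.configs.constants import MODEL_RESPONSE_BATCH_TIMEOUT_MS
from ray.llm._internal.serve.deployments.utils.batcher import OpenAIResponseBatcher


async def batch_openai_stream(openai_chunk_gen, experimental_configs: dict):
    """Flush OpenAI stream chunks in interval-sized batches instead of one at a time."""
    interval_ms = experimental_configs.get(
        "stream_batching_interval_ms", MODEL_RESPONSE_BATCH_TIMEOUT_MS
    )
    batcher = OpenAIResponseBatcher(openai_chunk_gen, interval_ms=interval_ms)
    # Each item yielded by stream() is a list of chunks collected during one
    # interval; the router serializes the whole list in a single SSE write.
    async for chunk_batch in batcher.stream():
        yield chunk_batch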

python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py

Lines changed: 0 additions & 25 deletions
@@ -49,12 +49,10 @@
 from ray.llm._internal.serve.configs.constants import (
     RAYLLM_ENABLE_REQUEST_PROMPT_LOGS,
     RAYLLM_GUIDED_DECODING_BACKEND,
-    MODEL_RESPONSE_BATCH_TIMEOUT_MS,
     MIN_NUM_TOPLOGPROBS_ALLOWED,
     MAX_NUM_TOPLOGPROBS_ALLOWED,
 )
 from ray.llm._internal.utils import try_import
-from ray.llm._internal.serve.deployments.utils.batcher import LLMRawResponsesBatcher
 
 from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine
 
@@ -519,30 +517,7 @@ async def prepare_request(
         vllm_request = VLLMGenerationRequest(**request_params)
         return vllm_request
 
-    def _get_batch_interval_ms(self, stream: bool = True) -> int:
-        """Calculate the batching interval for responses."""
-        stream_batching_interval_ms = self.llm_config.experimental_configs.get(
-            "stream_batching_interval_ms"
-        )
-        if stream_batching_interval_ms is None:
-            stream_batching_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS
-        return stream_batching_interval_ms if stream else None
-
     async def generate(
-        self,
-        request: GenerationRequest,
-    ) -> AsyncGenerator[LLMRawResponse, None]:
-        # TODO (genesu): Responses batching logics should be common to all
-        # engines and belongs to the LLMServer level instead of the engine
-        # level here. Refactor the entire batching logics up.
-        response_stream = LLMRawResponsesBatcher(
-            self._generate(request),
-            interval_ms=self._get_batch_interval_ms(request.stream),
-        )
-        async for response in response_stream.stream():
-            yield response
-
-    async def _generate(
         self, request: GenerationRequest
     ) -> AsyncGenerator[LLMRawResponse, None]:
         """Generate an LLMRawResponse stream

python/ray/llm/_internal/serve/deployments/routers/router.py

Lines changed: 46 additions & 65 deletions
@@ -10,11 +10,9 @@
     Dict,
     List,
     Optional,
-    Tuple,
     Union,
 )
 
-
 from fastapi import FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from ray import serve
@@ -125,8 +123,11 @@ def _apply_openai_json_format(
     The converted strings are concatenated and returned:
         data: <response-json1>\n\ndata: <response-json2>\n\n...
     """
-
-    return "".join(f"data: {response.model_dump_json()}\n\n")
+    if isinstance(response, list):
+        return "".join(f"data: {r.model_dump_json()}\n\n" for r in response)
+    if hasattr(response, "model_dump_json"):
+        return f"data: {response.model_dump_json()}\n\n"
+    raise ValueError(f"Unexpected response type: {type(response)}")
 
 
 async def _openai_json_wrapper(
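Illustration (not part of the diff): with the list branch above, a batch of stream chunks collapses into one string carrying several data: frames, so the router performs a single write per batch rather than one per chunk. Chunk below is a stand-in pydantic model for the real stream-response types, and the helper mirrors the updated router logic under that assumption.

# Stand-in sketch; Chunk mimics the pydantic stream-response models used by the router.
from pydantic import BaseModel


class Chunk(BaseModel):
    id: str
    delta: str


def apply_openai_json_format(response):
    # Same shape as the updated helper: a list becomes several concatenated SSE
    # frames in one string, a single model becomes one frame.
    if isinstance(response, list):
        return "".join(f"data: {r.model_dump_json()}\n\n" for r in response)
    return f"data: {response.model_dump_json()}\n\n"


batch = [Chunk(id="c1", delta="Hel"), Chunk(id="c1", delta="lo")]
print(apply_openai_json_format(batch))
# Roughly:
# data: {"id":"c1","delta":"Hel"}
#
# data: {"id":"c1","delta":"lo"}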
@@ -147,29 +148,16 @@ async def _openai_json_wrapper(
     Yields:
         Concatenated JSON strings that represent CompletionStreamResponse.
     """
-    yield _apply_openai_json_format(first_response)
+    packet = _apply_openai_json_format(first_response)
+    yield packet
 
     async for response in generator:
-        yield _apply_openai_json_format(response)
+        packet = _apply_openai_json_format(response)
+        yield packet
 
     yield "data: [DONE]\n\n"
 
 
-async def _peek_at_openai_json_generator(
-    generator: Union[LLMChatResponse, LLMCompletionsResponse],
-) -> Tuple[
-    Union[ChatCompletionStreamResponse, CompletionStreamResponse, ErrorResponse],
-    AsyncGenerator[str, None],
-]:
-    """Runs one iteration of the underlying generator
-    and returns the result, alongside the generator itself (with the
-    first iteration still there).
-    """
-    first_response = await generator.__anext__()
-
-    return first_response, _openai_json_wrapper(generator, first_response)
-
-
 class LLMRouter:
     def __init__(
         self,
@@ -347,6 +335,41 @@ async def model_data(self, model: str) -> ModelData:
         )
         return model_data
 
+    async def _process_llm_request(
+        self, body: Union[CompletionRequest, ChatCompletionRequest], is_chat: bool
+    ) -> Response:
+        NoneStreamingResponseType = (
+            ChatCompletionResponse if is_chat else CompletionResponse
+        )
+        call_method = "chat" if is_chat else "completions"
+
+        async with timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
+
+            gen = self._get_response(body=body, call_method=call_method)
+
+            first_response = await gen.__anext__()
+
+            # In case of streaming the first response can be batched.
+            if body.stream and isinstance(first_response, list):
+                first_response = first_response[0]
+
+            if isinstance(first_response, ErrorResponse):
+                raise OpenAIHTTPException(
+                    message=first_response.message,
+                    status_code=first_response.code,
+                    type=first_response.type,
+                )
+
+            if isinstance(first_response, NoneStreamingResponseType):
+                # Not streaming
+                return JSONResponse(content=first_response.model_dump())
+
+            openai_stream_generator = _openai_json_wrapper(gen, first_response)
+
+            return StreamingResponse(
+                openai_stream_generator, media_type="text/event-stream"
+            )
+
     @fastapi_router_app.post("/v1/completions")
     async def completions(self, body: CompletionRequest) -> Response:
         """Given a prompt, the model will return one or more predicted completions,
@@ -355,28 +378,7 @@ async def completions(self, body: CompletionRequest) -> Response:
         Returns:
             A response object with completions.
         """
-        async with timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
-            results = self._get_response(body=body, call_method="completions")
-            if body.stream:
-                first_response, wrapper = await _peek_at_openai_json_generator(results)
-                if isinstance(first_response, ErrorResponse):
-                    raise OpenAIHTTPException(
-                        message=first_response.message,
-                        status_code=first_response.code,
-                        type=first_response.type,
-                    )
-                return StreamingResponse(wrapper, media_type="text/event-stream")
-
-            result = await results.__anext__()
-            if isinstance(result, ErrorResponse):
-                raise OpenAIHTTPException(
-                    message=result.message,
-                    status_code=result.code,
-                    type=result.type,
-                )
-
-            if isinstance(result, CompletionResponse):
-                return JSONResponse(content=result.model_dump())
+        return await self._process_llm_request(body, is_chat=False)
 
     @fastapi_router_app.post("/v1/chat/completions")
     async def chat(self, body: ChatCompletionRequest) -> Response:
@@ -387,28 +389,7 @@ async def chat(self, body: ChatCompletionRequest) -> Response:
             A response object with completions.
         """
 
-        async with timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
-            results = self._get_response(body=body, call_method="chat")
-            if body.stream:
-                first_response, wrapper = await _peek_at_openai_json_generator(results)
-                if isinstance(first_response, ErrorResponse):
-                    raise OpenAIHTTPException(
-                        message=first_response.message,
-                        status_code=first_response.code,
-                        type=first_response.type,
-                    )
-                return StreamingResponse(wrapper, media_type="text/event-stream")
-
-            result = await results.__anext__()
-            if isinstance(result, ErrorResponse):
-                raise OpenAIHTTPException(
-                    message=result.message,
-                    status_code=result.code,
-                    type=result.type,
-                )
-
-            if isinstance(result, ChatCompletionResponse):
-                return JSONResponse(content=result.model_dump())
+        return await self._process_llm_request(body, is_chat=True)
 
     @classmethod
     def as_deployment(
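Both endpoints stay OpenAI-compatible, so the server-side batching is invisible to clients: each SSE frame still decodes to a single chunk. For reference, a minimal streaming client; the base URL and model id below are placeholders, not values from this commit.

# Placeholder base_url / model id; any OpenAI-compatible client works unchanged.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

stream = client.chat.completions.create(
    model="my-served-model",  # placeholder model id
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    # The client still sees one ChatCompletionChunk per SSE frame, even though
    # the server may flush several frames per network write.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")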

python/ray/llm/_internal/serve/deployments/utils/batcher.py

Lines changed: 22 additions & 5 deletions
@@ -1,5 +1,5 @@
 import asyncio
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Optional, Iterable, List, TypeVar, Generic
 
 
 from ray.llm._internal.serve.observability.logging import get_logger
@@ -15,8 +15,10 @@
 
 logger = get_logger(__name__)
 
+T = TypeVar("T")
 
-class LLMRawResponsesBatcher:
+
+class Batcher(Generic[T]):
     """This class batches multiple LLMRawResponses from a generator into a
     single response, at some time interval.
 
@@ -30,7 +32,7 @@ class LLMRawResponsesBatcher:
 
     def __init__(
         self,
-        generator: AsyncGenerator[LLMRawResponse, None],
+        generator: AsyncGenerator[T, None],
        interval_ms: Optional[float] = MODEL_RESPONSE_BATCH_TIMEOUT_MS,
     ):
         self.generator = generator
@@ -46,7 +48,10 @@ def __init__(
         # We are okay with this task getting cancelled (to propagate cancellations)
         self.read_task = asyncio.create_task(self.read())
 
-    async def stream(self) -> AsyncGenerator[BatchedLLMRawResponse, None]:
+    def _merge_results(self, results: List[T]) -> Iterable[T]:
+        return results
+
+    async def stream(self) -> AsyncGenerator[Iterable[T], None]:
         """Drain from the queue every interval_ms and yield the merged results"""
         try:
             while True:
@@ -67,7 +72,7 @@ async def stream(self) -> AsyncGenerator[BatchedLLMRawResponse, None]:
 
                 # If there are results, merge and yield them
                 if results:
-                    output: BatchedLLMRawResponse = BatchedLLMRawResponse.merge_stream(*results)  # type: ignore
+                    output = self._merge_results(results)
                     yield output
 
                 # If the read task is done, exit the stream task
@@ -101,3 +106,15 @@ def drain_queue(self):
         except asyncio.QueueEmpty:
             pass
         return results
+
+
+class LLMRawResponseBatcher(Batcher):
+    """This class batches multiple LLMRawResponses into a single BatchedLLMRawResponse."""
+
+    def _merge_results(self, results: List[LLMRawResponse]) -> BatchedLLMRawResponse:
+        output: BatchedLLMRawResponse = BatchedLLMRawResponse.merge_stream(*results)  # type: ignore
+        return output
+
+
+class OpenAIResponseBatcher(Batcher):
+    """This class batches multiple OpenAI responses into a single OpenAI response."""

python/ray/llm/_internal/serve/deployments/utils/server_utils.py

Lines changed: 4 additions & 1 deletion
@@ -125,7 +125,10 @@ def get_response_for_error(
 
 def get_serve_request_id() -> str:
     """Get request id from serve request context."""
-    return serve.context._serve_request_context.get().request_id
+    context = serve.context._serve_request_context.get()
+    if context is not None:
+        return context.request_id
+    return ""
 
 
 def get_model_request_id(model: str):
