     Dict,
     List,
     Optional,
-    Tuple,
     Union,
 )

-
 from fastapi import FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from ray import serve
@@ -125,8 +123,11 @@ def _apply_openai_json_format(
     The converted strings are concatenated and returned:
         data: <response-json1>\n\n data: <response-json2>\n\n ...
     """
-
-    return "".join(f"data: {response.model_dump_json()}\n\n")
+    if isinstance(response, list):
+        return "".join(f"data: {r.model_dump_json()}\n\n" for r in response)
+    if hasattr(response, "model_dump_json"):
+        return f"data: {response.model_dump_json()}\n\n"
+    raise ValueError(f"Unexpected response type: {type(response)}")


 async def _openai_json_wrapper(
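For reference, a rough, self-contained sketch of the SSE framing this helper now produces. Chunk is a made-up stand-in for the real pydantic stream-response models, which all expose model_dump_json():

from pydantic import BaseModel

class Chunk(BaseModel):
    text: str

def apply_openai_json_format(response):
    # Same shape as the patched helper: a batched list becomes one
    # "data: <json>\n\n" frame per item, a single response becomes one frame.
    if isinstance(response, list):
        return "".join(f"data: {r.model_dump_json()}\n\n" for r in response)
    if hasattr(response, "model_dump_json"):
        return f"data: {response.model_dump_json()}\n\n"
    raise ValueError(f"Unexpected response type: {type(response)}")

print(apply_openai_json_format(Chunk(text="Hel")))
# data: {"text":"Hel"}
print(apply_openai_json_format([Chunk(text="Hel"), Chunk(text="lo")]))
# data: {"text":"Hel"}
#
# data: {"text":"lo"}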
@@ -147,29 +148,16 @@ async def _openai_json_wrapper(
     Yields:
         Concatenated JSON strings that represent CompletionStreamResponse.
     """
-    yield _apply_openai_json_format(first_response)
+    packet = _apply_openai_json_format(first_response)
+    yield packet

     async for response in generator:
-        yield _apply_openai_json_format(response)
+        packet = _apply_openai_json_format(response)
+        yield packet

     yield "data: [DONE]\n\n"


-async def _peek_at_openai_json_generator(
-    generator: Union[LLMChatResponse, LLMCompletionsResponse],
-) -> Tuple[
-    Union[ChatCompletionStreamResponse, CompletionStreamResponse, ErrorResponse],
-    AsyncGenerator[str, None],
-]:
-    """Runs one iteration of the underlying generator
-    and returns the result, alongside the generator itself (with the
-    first iteration still there).
-    """
-    first_response = await generator.__anext__()
-
-    return first_response, _openai_json_wrapper(generator, first_response)
-
-
 class LLMRouter:
     def __init__(
         self,
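The removed peek helper is folded into the caller: the first chunk is now read with __anext__() directly and passed in as first_response (see the new _process_llm_request later in this diff). A minimal, self-contained sketch of that consumption pattern; sse() and fake_stream() are toy stand-ins, not the real helpers:

import asyncio
import json

def sse(obj) -> str:
    # Toy stand-in for _apply_openai_json_format on a single chunk.
    return f"data: {json.dumps(obj)}\n\n"

async def fake_stream():
    for piece in ({"text": "Hel"}, {"text": "lo"}):
        yield piece

async def openai_json_wrapper(generator, first_response):
    # Mirrors the wrapper: emit the already-read first chunk, then the
    # remaining chunks, then the OpenAI-style terminator.
    yield sse(first_response)
    async for response in generator:
        yield sse(response)
    yield "data: [DONE]\n\n"

async def main():
    gen = fake_stream()
    first = await gen.__anext__()          # peek at the first chunk
    async for packet in openai_json_wrapper(gen, first):
        print(repr(packet))

asyncio.run(main())
# 'data: {"text": "Hel"}\n\n'
# 'data: {"text": "lo"}\n\n'
# 'data: [DONE]\n\n'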
@@ -347,6 +335,41 @@ async def model_data(self, model: str) -> ModelData:
             )
         return model_data

+    async def _process_llm_request(
+        self, body: Union[CompletionRequest, ChatCompletionRequest], is_chat: bool
+    ) -> Response:
+        NoneStreamingResponseType = (
+            ChatCompletionResponse if is_chat else CompletionResponse
+        )
+        call_method = "chat" if is_chat else "completions"
+
+        async with timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
+
+            gen = self._get_response(body=body, call_method=call_method)
+
+            first_response = await gen.__anext__()
+
+            # In case of streaming the first response can be batched.
+            if body.stream and isinstance(first_response, list):
+                first_response = first_response[0]
+
+            if isinstance(first_response, ErrorResponse):
+                raise OpenAIHTTPException(
+                    message=first_response.message,
+                    status_code=first_response.code,
+                    type=first_response.type,
+                )
+
+            if isinstance(first_response, NoneStreamingResponseType):
+                # Not streaming
+                return JSONResponse(content=first_response.model_dump())
+
+            openai_stream_generator = _openai_json_wrapper(gen, first_response)
+
+            return StreamingResponse(
+                openai_stream_generator, media_type="text/event-stream"
+            )
+
     @fastapi_router_app.post("/v1/completions")
     async def completions(self, body: CompletionRequest) -> Response:
         """Given a prompt, the model will return one or more predicted completions,
@@ -355,28 +378,7 @@ async def completions(self, body: CompletionRequest) -> Response:
         Returns:
             A response object with completions.
         """
-        async with timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
-            results = self._get_response(body=body, call_method="completions")
-            if body.stream:
-                first_response, wrapper = await _peek_at_openai_json_generator(results)
-                if isinstance(first_response, ErrorResponse):
-                    raise OpenAIHTTPException(
-                        message=first_response.message,
-                        status_code=first_response.code,
-                        type=first_response.type,
-                    )
-                return StreamingResponse(wrapper, media_type="text/event-stream")
-
-            result = await results.__anext__()
-            if isinstance(result, ErrorResponse):
-                raise OpenAIHTTPException(
-                    message=result.message,
-                    status_code=result.code,
-                    type=result.type,
-                )
-
-            if isinstance(result, CompletionResponse):
-                return JSONResponse(content=result.model_dump())
+        return await self._process_llm_request(body, is_chat=False)

     @fastapi_router_app.post("/v1/chat/completions")
     async def chat(self, body: ChatCompletionRequest) -> Response:
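With both endpoints delegating to _process_llm_request, a non-streaming request simply gets the single response back as JSON. A hypothetical client-side call against a locally served router; host, port, and model name are placeholders:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",          # placeholder model id
        "prompt": "Say hello",
        "stream": False,
    },
)
resp.raise_for_status()
print(resp.json())                    # a CompletionResponse serialized via model_dump()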
@@ -387,28 +389,7 @@ async def chat(self, body: ChatCompletionRequest) -> Response:
             A response object with completions.
         """

-        async with timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
-            results = self._get_response(body=body, call_method="chat")
-            if body.stream:
-                first_response, wrapper = await _peek_at_openai_json_generator(results)
-                if isinstance(first_response, ErrorResponse):
-                    raise OpenAIHTTPException(
-                        message=first_response.message,
-                        status_code=first_response.code,
-                        type=first_response.type,
-                    )
-                return StreamingResponse(wrapper, media_type="text/event-stream")
-
-            result = await results.__anext__()
-            if isinstance(result, ErrorResponse):
-                raise OpenAIHTTPException(
-                    message=result.message,
-                    status_code=result.code,
-                    type=result.type,
-                )
-
-            if isinstance(result, ChatCompletionResponse):
-                return JSONResponse(content=result.model_dump())
+        return await self._process_llm_request(body, is_chat=True)

     @classmethod
     def as_deployment(
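When stream=True, the same handler returns the text/event-stream wrapper instead, so a client reads "data: <json>" frames until the "[DONE]" sentinel. A hypothetical sketch of that client loop; URL and model name are placeholders:

import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",          # placeholder model id
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line:
        continue                      # skip the blank separators between SSE frames
    payload = line.removeprefix("data: ")
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)       # one ChatCompletionStreamResponse-shaped dict
    print(chunk)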