-
Notifications
You must be signed in to change notification settings - Fork 60
Removing kernel messaging in aqua #1304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
23a87a9
e7c18d3
ab7e984
a534b62
30f8e47
f45350a
6e46a8c
3c4895b
df27ccf
39cc70c
ead77e7
7c1d125
388289d
e43d541
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,8 +7,8 @@ | |
|
|
||
| from tornado.web import HTTPError | ||
|
|
||
| from ads.aqua.app import logger | ||
| from ads.aqua.client.client import Client, ExtendedRequestError | ||
| from ads.aqua.client.openai_client import OpenAI | ||
| from ads.aqua.common.decorator import handle_exceptions | ||
| from ads.aqua.common.enums import PredictEndpoints | ||
| from ads.aqua.extension.base_handler import AquaAPIhandler | ||
|
|
@@ -221,11 +221,49 @@ def list_shapes(self): | |
|
|
||
|
|
||
| class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler): | ||
|
|
||
| def _extract_text_from_choice(self, choice): | ||
| # choice may be a dict or an object | ||
| if isinstance(choice, dict): | ||
| # streaming chunk: {"delta": {"content": "..."}} | ||
| delta = choice.get("delta") | ||
| if isinstance(delta, dict): | ||
| return delta.get("content") or delta.get("text") or None | ||
| # non-streaming: {"message": {"content": "..."}} | ||
| msg = choice.get("message") | ||
| if isinstance(msg, dict): | ||
| return msg.get("content") or msg.get("text") | ||
| # fallback top-level fields | ||
| return choice.get("text") or choice.get("content") | ||
| # object-like choice | ||
| delta = getattr(choice, "delta", None) | ||
| if delta is not None: | ||
| return getattr(delta, "content", None) or getattr(delta, "text", None) | ||
| msg = getattr(choice, "message", None) | ||
| if msg is not None: | ||
| if isinstance(msg, str): | ||
| return msg | ||
| return getattr(msg, "content", None) or getattr(msg, "text", None) | ||
| return getattr(choice, "text", None) or getattr(choice, "content", None) | ||
|
|
||
| def _extract_text_from_chunk(self, chunk): | ||
| if chunk : | ||
VipulMascarenhas marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if isinstance(chunk, dict): | ||
| choices = chunk.get("choices") or [] | ||
| if choices: | ||
| return self._extract_text_from_choice(choices[0]) | ||
| # fallback top-level | ||
| return chunk.get("text") or chunk.get("content") | ||
| # object-like chunk | ||
| choices = getattr(chunk, "choices", None) | ||
| if choices: | ||
| return self._extract_text_from_choice(choices[0]) | ||
| return getattr(chunk, "text", None) or getattr(chunk, "content", None) | ||
|
|
||
| def _get_model_deployment_response( | ||
| self, | ||
| model_deployment_id: str, | ||
| payload: dict, | ||
| route_override_header: Optional[str], | ||
| payload: dict | ||
| ): | ||
| """ | ||
| Returns the model deployment inference response in a streaming fashion. | ||
|
|
@@ -272,49 +310,160 @@ def _get_model_deployment_response( | |
| """ | ||
|
|
||
| model_deployment = AquaDeploymentApp().get(model_deployment_id) | ||
| endpoint = model_deployment.endpoint + "/predictWithResponseStream" | ||
| endpoint_type = model_deployment.environment_variables.get( | ||
| "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT | ||
| ) | ||
| aqua_client = Client(endpoint=endpoint) | ||
|
|
||
| if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in ( | ||
| endpoint_type, | ||
| route_override_header, | ||
| ): | ||
| endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1" | ||
| endpoint_type = payload["endpoint_type"] | ||
VipulMascarenhas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| aqua_client = OpenAI(base_url=endpoint) | ||
|
|
||
| allowed = { | ||
| "max_tokens", | ||
| "temperature", | ||
| "top_p", | ||
| "stop", | ||
| "n", | ||
| "presence_penalty", | ||
| "frequency_penalty", | ||
| "logprobs", | ||
| "user", | ||
| "echo", | ||
| } | ||
| responses_allowed = { | ||
| "temperature", "top_p" | ||
| } | ||
|
|
||
| # normalize and filter | ||
| if payload.get("stop") == []: | ||
| payload["stop"] = None | ||
|
|
||
| encoded_image = "NA" | ||
VipulMascarenhas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if encoded_image in payload : | ||
| encoded_image = payload["encoded_image"] | ||
|
|
||
| model = payload.pop("model") | ||
| filtered = {k: v for k, v in payload.items() if k in allowed} | ||
| responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed} | ||
|
|
||
| if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type and encoded_image == "NA": | ||
| try: | ||
| for chunk in aqua_client.chat( | ||
| messages=payload.pop("messages"), | ||
| payload=payload, | ||
| stream=True, | ||
| ): | ||
| try: | ||
| if "text" in chunk["choices"][0]: | ||
| yield chunk["choices"][0]["text"] | ||
| elif "content" in chunk["choices"][0]["delta"]: | ||
| yield chunk["choices"][0]["delta"]["content"] | ||
| except Exception as e: | ||
| logger.debug( | ||
| f"Exception occurred while parsing streaming response: {e}" | ||
| ) | ||
| api_kwargs = { | ||
| "model": model, | ||
| "messages": [{"role": "user", "content": payload["prompt"]}], | ||
VipulMascarenhas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "stream": True, | ||
| **filtered | ||
| } | ||
| if "chat_template" in payload: | ||
| chat_template = payload.pop("chat_template") | ||
| api_kwargs["extra_body"] = {"chat_template": chat_template} | ||
|
|
||
| stream = aqua_client.chat.completions.create(**api_kwargs) | ||
|
|
||
| for chunk in stream: | ||
| if chunk : | ||
| piece = self._extract_text_from_chunk(chunk) | ||
| if piece : | ||
| yield piece | ||
| except ExtendedRequestError as ex: | ||
| raise HTTPError(400, str(ex)) | ||
| except Exception as ex: | ||
| raise HTTPError(500, str(ex)) | ||
|
|
||
| elif ( | ||
| endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT | ||
| and encoded_image != "NA" | ||
| ): | ||
| file_type = payload.pop("file_type") | ||
| if file_type.startswith("image"): | ||
|
||
| api_kwargs = { | ||
| "model": model, | ||
| "messages": [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| {"type": "text", "text": payload["prompt"]}, | ||
VipulMascarenhas marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
| "type": "image_url", | ||
| "image_url": {"url": f"{self.encoded_image}"}, | ||
|
||
| }, | ||
| ], | ||
| } | ||
| ], | ||
| "stream": True, | ||
| **filtered | ||
| } | ||
|
|
||
| # Add chat_template for image-based chat completions | ||
| if "chat_template" in payload: | ||
| chat_template = payload.pop("chat_template") | ||
| api_kwargs["extra_body"] = {"chat_template": chat_template} | ||
|
|
||
| response = aqua_client.chat.completions.create(**api_kwargs) | ||
|
|
||
| elif self.file_type.startswith("audio"): | ||
|
||
| api_kwargs = { | ||
| "model": model, | ||
| "messages": [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| {"type": "text", "text": payload["prompt"]}, | ||
| { | ||
| "type": "audio_url", | ||
| "audio_url": {"url": f"{self.encoded_image}"}, | ||
|
||
| }, | ||
| ], | ||
| } | ||
| ], | ||
| "stream": True, | ||
| **filtered | ||
| } | ||
|
|
||
| # Add chat_template for audio-based chat completions | ||
| if "chat_template" in payload: | ||
| chat_template = payload.pop("chat_template") | ||
| api_kwargs["extra_body"] = {"chat_template": chat_template} | ||
|
|
||
| response = aqua_client.chat.completions.create(**api_kwargs) | ||
| try: | ||
| for chunk in response: | ||
| piece = self._extract_text_from_chunk(chunk) | ||
| if piece: | ||
| print(piece, end="", flush=True) | ||
|
||
| except ExtendedRequestError as ex: | ||
| raise HTTPError(400, str(ex)) | ||
| except Exception as ex: | ||
| raise HTTPError(500, str(ex)) | ||
| elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT: | ||
| try: | ||
| for chunk in aqua_client.generate( | ||
| prompt=payload.pop("prompt"), | ||
| payload=payload, | ||
| stream=True, | ||
| for chunk in aqua_client.completions.create( | ||
| prompt=payload["prompt"], stream=True, model=model, **filtered | ||
| ): | ||
| try: | ||
| yield chunk["choices"][0]["text"] | ||
| except Exception as e: | ||
| logger.debug( | ||
| f"Exception occurred while parsing streaming response: {e}" | ||
| ) | ||
| if chunk : | ||
| piece = self._extract_text_from_chunk(chunk) | ||
| if piece : | ||
| yield piece | ||
| except ExtendedRequestError as ex: | ||
| raise HTTPError(400, str(ex)) | ||
| except Exception as ex: | ||
| raise HTTPError(500, str(ex)) | ||
|
|
||
| elif endpoint_type == PredictEndpoints.RESPONSES: | ||
| api_kwargs = { | ||
| "model": model, | ||
| "input": payload["prompt"], | ||
| "stream": True | ||
| } | ||
|
|
||
| if "temperature" in responses_filtered: | ||
| api_kwargs["temperature"] = responses_filtered["temperature"] | ||
| if "top_p" in responses_filtered: | ||
| api_kwargs["top_p"] = responses_filtered["top_p"] | ||
|
|
||
| response = aqua_client.responses.create(**api_kwargs) | ||
| try: | ||
| for chunk in response: | ||
| if chunk : | ||
| piece = self._extract_text_from_chunk(chunk) | ||
| if piece : | ||
| yield piece | ||
| except ExtendedRequestError as ex: | ||
| raise HTTPError(400, str(ex)) | ||
| except Exception as ex: | ||
VipulMascarenhas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
@@ -340,19 +489,20 @@ def post(self, model_deployment_id): | |
| prompt = input_data.get("prompt") | ||
| messages = input_data.get("messages") | ||
|
|
||
|
|
||
| if not prompt and not messages: | ||
| raise HTTPError( | ||
| 400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages") | ||
| ) | ||
| if not input_data.get("model"): | ||
| raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model")) | ||
| route_override_header = self.request.headers.get("route", None) | ||
| self.set_header("Content-Type", "text/event-stream") | ||
| response_gen = self._get_model_deployment_response( | ||
| model_deployment_id, input_data, route_override_header | ||
| model_deployment_id, input_data | ||
| ) | ||
| try: | ||
| for chunk in response_gen: | ||
| print(chunk) | ||
|
||
| self.write(chunk) | ||
| self.flush() | ||
| self.finish() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add docstrings, use type hinting