Removing kernal messaging in aqua #1304
Merged
Commits (14):
- 23a87a9 deployment inference using openAI client (agrimk)
- e7c18d3 Merge branch 'main' of github.com:oracle/accelerated-data-science int… (agrimk)
- ab7e984 Merge branch 'main' of github.com:oracle/accelerated-data-science int… (agrimk)
- a534b62 stream inference endpoint (agrimk)
- 30f8e47 Merge branch 'main' into removing_kernal_messaging_in_aqua (agrimk)
- f45350a unit test fixes (agrimk)
- 6e46a8c Merge branch 'removing_kernal_messaging_in_aqua' of github.com:oracle… (agrimk)
- 3c4895b Merge branch 'main' of github.com:oracle/accelerated-data-science int… (agrimk)
- df27ccf added test cases and PR review comments (agrimk)
- 39cc70c fixing test cases (agrimk)
- ead77e7 fixed handling of encoded_image (agrimk)
- 7c1d125 removing an ocid (agrimk)
- 388289d running precommit hook (agrimk)
- e43d541 Merge branch 'main' into removing_kernal_messaging_in_aqua (agrimk)
Files changed:

```diff
@@ -7,14 +7,15 @@
 from tornado.web import HTTPError
 
-from ads.aqua.app import logger
 from ads.aqua.client.client import Client, ExtendedRequestError
+from ads.aqua.client.openai_client import OpenAI
 from ads.aqua.common.decorator import handle_exceptions
 from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
 from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.config import COMPARTMENT_OCID
+from ads.aqua import logger
 
 
 class AquaDeploymentHandler(AquaAPIhandler):
```
```diff
@@ -221,11 +222,102 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _extract_text_from_choice(self, choice: dict) -> Optional[str]:
+        """
+        Extract text content from a single choice structure.
+
+        Handles both dictionary-based API responses and object-based SDK responses.
+        For dict choices, it checks delta-based streaming fields, message-based
+        non-streaming fields, and finally top-level text/content keys.
+        For object choices, it inspects `.delta`, `.message`, and top-level
+        `.text` or `.content` attributes.
+
+        Parameters
+        ----------
+        choice : dict
+            A choice entry from a model response. It may be:
+            - A dict originating from a JSON API response (streaming or non-streaming).
+            - An SDK-style object with attributes such as `delta`, `message`,
+              `text`, or `content`.
+
+        Returns
+        -------
+        str | None
+            The extracted text if present; otherwise None.
+        """
+        # choice may be a dict or an object
+        if isinstance(choice, dict):
+            # streaming chunk: {"delta": {"content": "..."}}
+            delta = choice.get("delta")
+            if isinstance(delta, dict):
+                return delta.get("content") or delta.get("text") or None
+            # non-streaming: {"message": {"content": "..."}}
+            msg = choice.get("message")
+            if isinstance(msg, dict):
+                return msg.get("content") or msg.get("text")
+            # fallback top-level fields
+            return choice.get("text") or choice.get("content")
+        # object-like choice
+        delta = getattr(choice, "delta", None)
+        if delta is not None:
+            return getattr(delta, "content", None) or getattr(delta, "text", None)
+        msg = getattr(choice, "message", None)
+        if msg is not None:
+            if isinstance(msg, str):
+                return msg
+            return getattr(msg, "content", None) or getattr(msg, "text", None)
+        return getattr(choice, "text", None) or getattr(choice, "content", None)
+
+    def _extract_text_from_chunk(self, chunk: dict) -> Optional[str]:
+        """
+        Extract text content from a model response chunk.
+
+        Supports both dict-form chunks (streaming or non-streaming) and SDK-style
+        object chunks. When choices are present, extraction is delegated to
+        `_extract_text_from_choice`. If no choices exist, top-level text/content
+        fields or attributes are used.
+
+        Parameters
+        ----------
+        chunk : dict
+            A chunk returned from a model stream or full response. It may be:
+            - A dict containing a `choices` list or top-level text/content fields.
+            - An SDK-style object with a `choices` attribute or top-level
+              `text`/`content` attributes.
+
+        Returns
+        -------
+        str | None
+            The extracted text if present; otherwise None.
+        """
+        if chunk:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices") or []
+                if choices:
+                    return self._extract_text_from_choice(choices[0])
+                # fallback top-level
+                return chunk.get("text") or chunk.get("content")
+            # object-like chunk
+            choices = getattr(chunk, "choices", None)
+            if choices:
+                return self._extract_text_from_choice(choices[0])
+            return getattr(chunk, "text", None) or getattr(chunk, "content", None)
+
     def _get_model_deployment_response(
         self,
         model_deployment_id: str,
-        payload: dict,
-        route_override_header: Optional[str],
+        payload: dict
     ):
         """
         Returns the model deployment inference response in a streaming fashion.
```
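The two extraction helpers never touch Tornado request state, so they can be exercised on their own. A minimal sketch, with illustrative chunk shapes; the `__new__` construction (skipping `RequestHandler.__init__`, which needs an application and request) is a testing convenience, not part of the PR:

```python
# Sketch only: assumes AquaDeploymentStreamingInferenceHandler is importable
# from the ADS extension module. __new__ skips Tornado's __init__, which the
# helpers never rely on.
handler = AquaDeploymentStreamingInferenceHandler.__new__(
    AquaDeploymentStreamingInferenceHandler
)

chunks = [
    {"choices": [{"delta": {"content": "Hel"}}]},   # streaming chat delta
    {"choices": [{"message": {"content": "Hi"}}]},  # non-streaming chat message
    {"choices": [{"text": "Hey"}]},                 # text-completion chunk
    {"text": "Yo"},                                 # no choices: top-level fallback
]
for chunk in chunks:
    print(handler._extract_text_from_chunk(chunk))  # Hel, Hi, Hey, Yo
```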
```diff
@@ -272,53 +364,173 @@ def _get_model_deployment_response(
         """
 
         model_deployment = AquaDeploymentApp().get(model_deployment_id)
-        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
-        endpoint_type = model_deployment.environment_variables.get(
-            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
-        )
-        aqua_client = Client(endpoint=endpoint)
-
-        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
-            endpoint_type,
-            route_override_header,
-        ):
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
+
+        required_keys = ["endpoint_type", "prompt", "model"]
+        missing = [k for k in required_keys if k not in payload]
+
+        if missing:
+            raise HTTPError(400, f"Missing required payload keys: {', '.join(missing)}")
+
+        endpoint_type = payload["endpoint_type"]
+        aqua_client = OpenAI(base_url=endpoint)
+
+        allowed = {
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "presence_penalty",
+            "frequency_penalty",
+            "logprobs",
+            "user",
+            "echo",
+        }
+        responses_allowed = {"temperature", "top_p"}
+
+        # normalize and filter
+        if payload.get("stop") == []:
+            payload["stop"] = None
+
+        encoded_image = "NA"
+        if "encoded_image" in payload:
+            encoded_image = payload["encoded_image"]
+
+        model = payload.pop("model")
+        filtered = {k: v for k, v in payload.items() if k in allowed}
+        responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}
```
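For reference, a request body that passes the new `required_keys` check and survives the `allowed` filter might look like the following (continued in the branch logic below). All names and values are illustrative, and the `endpoint_type` string assumes the `PredictEndpoints` members carry OpenAI-style route names:

```python
# Illustrative payload for the new contract: "endpoint_type", "prompt", and
# "model" are required; other keys are kept only if listed in `allowed`.
payload = {
    "endpoint_type": "/v1/chat/completions",  # assumed PredictEndpoints value
    "prompt": "Summarize streaming inference in one sentence.",
    "model": "odsc-llm",
    "temperature": 0.2,        # kept; also one of the two keys RESPONSES forwards
    "max_tokens": 256,         # kept: in `allowed`
    "stop": [],                # normalized to None before the client call
    "client_request_id": "x",  # dropped: not in `allowed`
}
```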
```diff
+        if (
+            PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type
+            and encoded_image == "NA"
+        ):
             try:
-                for chunk in aqua_client.chat(
-                    messages=payload.pop("messages"),
-                    payload=payload,
-                    stream=True,
-                ):
-                    try:
-                        if "text" in chunk["choices"][0]:
-                            yield chunk["choices"][0]["text"]
-                        elif "content" in chunk["choices"][0]["delta"]:
-                            yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                api_kwargs = {
+                    "model": model,
+                    "messages": [{"role": "user", "content": payload["prompt"]}],
+                    "stream": True,
+                    **filtered,
+                }
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                stream = aqua_client.chat.completions.create(**api_kwargs)
+
+                for chunk in stream:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
 
+        elif (
+            endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
+            and encoded_image != "NA"
+        ):
+            file_type = payload.pop("file_type")
+            if file_type.startswith("image"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": f"{encoded_image}"},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered,
+                }
+
+                # Add chat_template for image-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+
+            elif file_type.startswith("audio"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "audio_url",
+                                    "audio_url": {"url": f"{encoded_image}"},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered,
+                }
+
+                # Add chat_template for audio-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+            try:
+                for chunk in response:
+                    piece = self._extract_text_from_chunk(chunk)
+                    if piece:
+                        yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
```
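The multimodal branch expects the file content to be inlined in the payload already: `encoded_image` is passed through verbatim as the `image_url`/`audio_url` value, so a data URL is the natural shape. A sketch of the extra fields, with a truncated placeholder value:

```python
# Illustrative multimodal additions to the payload above. "file_type" selects
# the image vs. audio message part; "encoded_image" is forwarded verbatim,
# so it is assumed to already be a data URL (value truncated, placeholder only).
payload.update(
    {
        "file_type": "image/png",
        "encoded_image": "data:image/png;base64,iVBORw0KGgo...",
    }
)
```

Note that the same `encoded_image` field carries the audio payload in the `audio_url` branch; only the message part type changes.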
```diff
         elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
             try:
-                for chunk in aqua_client.generate(
-                    prompt=payload.pop("prompt"),
-                    payload=payload,
-                    stream=True,
+                for chunk in aqua_client.completions.create(
+                    prompt=payload["prompt"], stream=True, model=model, **filtered
                 ):
-                    try:
-                        yield chunk["choices"][0]["text"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
 
+        elif endpoint_type == PredictEndpoints.RESPONSES:
+            api_kwargs = {
+                "model": model,
+                "input": payload["prompt"],
+                "stream": True,
+            }
+
+            if "temperature" in responses_filtered:
+                api_kwargs["temperature"] = responses_filtered["temperature"]
+            if "top_p" in responses_filtered:
+                api_kwargs["top_p"] = responses_filtered["top_p"]
+
+            response = aqua_client.responses.create(**api_kwargs)
+            try:
+                for chunk in response:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+        else:
+            raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")
```
```diff
@@ -340,24 +552,24 @@ def post(self, model_deployment_id):
         prompt = input_data.get("prompt")
         messages = input_data.get("messages")
 
         if not prompt and not messages:
             raise HTTPError(
                 400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
             )
         if not input_data.get("model"):
             raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
-        route_override_header = self.request.headers.get("route", None)
         self.set_header("Content-Type", "text/event-stream")
         response_gen = self._get_model_deployment_response(
-            model_deployment_id, input_data, route_override_header
+            model_deployment_id, input_data
         )
         try:
             for chunk in response_gen:
                 self.write(chunk)
                 self.flush()
             self.finish()
         except Exception as ex:
-            self.set_status(ex.status_code)
+            self.set_status(getattr(ex, "status_code", 500))
             self.write({"message": "Error occurred", "reason": str(ex)})
             self.finish()
```
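Consuming the endpoint from a client is then a plain streaming POST. The handler writes bare text pieces with no SSE `data:` framing and flushes after each write, so incremental reads work directly. A minimal sketch, assuming a hypothetical extension URL; the notebook host, route, and deployment OCID below are placeholders:

```python
# Sketch only: the URL path is hypothetical and the OCID is a placeholder.
import requests

url = (
    "http://localhost:8888/aqua/deployments/"
    "ocid1.datasciencemodeldeployment.oc1..example/inference"
)
body = {
    "endpoint_type": "/v1/chat/completions",  # assumed PredictEndpoints value
    "prompt": "Hello!",
    "model": "odsc-llm",
}

with requests.post(url, json=body, stream=True, timeout=600) as resp:
    resp.raise_for_status()
    # Pieces arrive as they are flushed by the handler, not as one body.
    for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(piece, end="", flush=True)
```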