-
Notifications
You must be signed in to change notification settings - Fork 697
FEAT expand TargetCapabilities #1464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8900ed5
9a39cc7
a687aa1
4537a8b
707f42d
1fe3f93
80e5a9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,12 +23,21 @@ class PromptTarget(Identifiable): | |
|
|
||
| _memory: MemoryInterface | ||
|
|
||
| #: A list of PromptConverters that are supported by the prompt target. | ||
| #: An empty list implies that the prompt target supports all converters. | ||
| # A list of PromptConverters that are supported by the prompt target. | ||
| # An empty list implies that the prompt target supports all converters. | ||
| supported_converters: list[Any] | ||
|
|
||
| _identifier: Optional[ComponentIdentifier] = None | ||
|
|
||
| # Class-level default capabilities for this target type. | ||
| # | ||
| # Subclasses **should** override this when their capabilities differ from the base | ||
| # defaults (e.g., to declare multi-turn support or non-text modalities). | ||
| # Overriding is *optional* — if a subclass does not define ``_DEFAULT_CAPABILITIES``, | ||
| # it inherits the base-class default (text-only, single-turn, no JSON response). | ||
| # | ||
| # Per-instance overrides are also possible via the ``custom_capabilities`` | ||
| # constructor parameter, which takes precedence over the class-level value. | ||
| _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities() | ||
|
|
||
| def __init__( | ||
|
|
@@ -38,7 +47,7 @@ def __init__( | |
| endpoint: str = "", | ||
| model_name: str = "", | ||
| underlying_model: Optional[str] = None, | ||
| capabilities: Optional[TargetCapabilities] = None, | ||
| custom_capabilities: Optional[TargetCapabilities] = None, | ||
| ) -> None: | ||
| """ | ||
| Initialize the PromptTarget. | ||
|
|
@@ -52,7 +61,7 @@ def __init__( | |
| identification purposes. This is useful when the deployment name in Azure differs | ||
| from the actual model. If not provided, `model_name` will be used for the identifier. | ||
| Defaults to None. | ||
| capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||
| custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we could have a mapping of known default capabilities here, or potentially retrieved from a list in target_capabilites class.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We also talked about a method to discover these which I think will be useful. But the defaults we know could go a long way early on :) |
||
| this target instance. Useful for targets whose capabilities depend on deployment | ||
| configuration (e.g., Playwright, HTTP). If None, uses the class-level | ||
| ``_DEFAULT_CAPABILITIES``. Defaults to None. | ||
|
|
@@ -63,7 +72,9 @@ def __init__( | |
| self._endpoint = endpoint | ||
| self._model_name = model_name | ||
| self._underlying_model = underlying_model | ||
| self._capabilities = capabilities if capabilities is not None else type(self)._DEFAULT_CAPABILITIES | ||
| self._capabilities = ( | ||
hannahwestra25 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| custom_capabilities if custom_capabilities is not None else type(self)._DEFAULT_CAPABILITIES | ||
hannahwestra25 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
|
|
||
| if self._verbose: | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
@@ -78,14 +89,38 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: | |
| but some (like response target with tool calls) may return multiple messages. | ||
| """ | ||
|
|
||
| @abc.abstractmethod | ||
| def _validate_request(self, *, message: Message) -> None: | ||
| """ | ||
| Validate the provided message. | ||
|
|
||
| Args: | ||
| message: The message to validate. | ||
|
|
||
| Raises: | ||
| ValueError: if the target does not support the provided message pieces or if the message | ||
| violates any constraints based on the target's capabilities. This includes checks | ||
| for the number of message pieces, supported data types, and multi-turn conversation support. | ||
|
|
||
| """ | ||
| n_pieces = len(message.message_pieces) | ||
| if not self.capabilities.supports_multi_message_pieces and n_pieces != 1: | ||
| raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") | ||
|
|
||
| for piece in message.message_pieces: | ||
| piece_type = piece.converted_value_data_type | ||
| supported_types_flat = {t for combo in self.capabilities.input_modalities for t in combo} | ||
| if piece_type not in supported_types_flat: | ||
| supported_types = ", ".join(sorted(supported_types_flat)) | ||
| raise ValueError( | ||
| f"This target supports only the following data types: {supported_types}. Received: {piece_type}." | ||
| ) | ||
|
|
||
| if not self.capabilities.supports_multi_turn: | ||
| request = message.message_pieces[0] | ||
| messages = self._memory.get_message_pieces(conversation_id=request.conversation_id) | ||
|
|
||
| if len(messages) > 0: | ||
| raise ValueError("This target only supports a single turn conversation.") | ||
|
|
||
| def set_model_name(self, *, model_name: str) -> None: | ||
| """ | ||
|
|
@@ -133,7 +168,7 @@ def _create_identifier( | |
| "endpoint": self._endpoint, | ||
| "model_name": model_name, | ||
| "max_requests_per_minute": self._max_requests_per_minute, | ||
| "supports_multi_turn": self.supports_multi_turn, | ||
| "supports_multi_turn": self.capabilities.supports_multi_turn, | ||
| } | ||
| if params: | ||
| all_params.update(params) | ||
|
|
@@ -155,19 +190,6 @@ def capabilities(self) -> TargetCapabilities: | |
| """ | ||
| return self._capabilities | ||
|
|
||
| @property | ||
| def supports_multi_turn(self) -> bool: | ||
| """ | ||
| Whether this target supports multi-turn conversations. | ||
|
|
||
| Convenience property that delegates to ``self.capabilities.supports_multi_turn``. | ||
|
|
||
| Returns: | ||
| bool: False by default. Subclasses declare multi-turn support by setting | ||
| ``_DEFAULT_CAPABILITIES`` or passing ``capabilities`` to the constructor. | ||
| """ | ||
| return self._capabilities.supports_multi_turn | ||
|
|
||
| def _build_identifier(self) -> ComponentIdentifier: | ||
| """ | ||
| Build the identifier for this target. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.