Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def _setup_async(self, *, context: ChunkedRequestAttackContext) -> None:
Raises:
ValueError: If the objective target does not support multi-turn conversations.
"""
if not self._objective_target.supports_multi_turn:
if not self._objective_target.capabilities.supports_multi_turn:
raise ValueError(
"ChunkedRequestAttack requires a multi-turn target. "
"The objective target does not support multi-turn conversations."
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None:
Raises:
ValueError: If the objective target does not support multi-turn conversations.
"""
if not self._objective_target.supports_multi_turn:
if not self._objective_target.capabilities.supports_multi_turn:
raise ValueError(
"CrescendoAttack requires a multi-turn target. Crescendo fundamentally relies on "
"multi-turn conversation history to gradually escalate prompts. "
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/multi_prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None:
Raises:
ValueError: If the objective target does not support multi-turn conversations.
"""
if not self._objective_target.supports_multi_turn:
if not self._objective_target.capabilities.supports_multi_turn:
raise ValueError(
"MultiPromptSendingAttack requires a multi-turn target. "
"The objective target does not support multi-turn conversations."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _rotate_conversation_for_single_turn_target(
Args:
context: The current attack context.
"""
if self._objective_target.supports_multi_turn:
if self._objective_target.capabilities.supports_multi_turn:
return

if context.executed_turns == 0:
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def duplicate(self) -> "_TreeOfAttacksNode":
# For single-turn targets, duplicate only the system messages (e.g., system prompt
# from prepended conversation) so the target retains its configuration without
# carrying over attack turn history that would cause validation errors.
if self._objective_target.supports_multi_turn:
if self._objective_target.capabilities.supports_multi_turn:
duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation(
conversation_id=self.objective_target_conversation_id
)
Expand Down
40 changes: 24 additions & 16 deletions pyrit/prompt_target/azure_blob_storage_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,13 +50,29 @@ class AzureBlobStorageTarget(PromptTarget):
AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_CONTAINER_URL"
SAS_TOKEN_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_SAS_TOKEN"

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(
supports_multi_message_pieces=True,
input_modalities=frozenset(
{
frozenset(["text"]),
frozenset(["url"]),
}
),
output_modalities=frozenset(
{
frozenset(["url"]),
}
),
)

def __init__(
self,
*,
container_url: Optional[str] = None,
sas_token: Optional[str] = None,
blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT,
max_requests_per_minute: Optional[int] = None,
custom_capabilities: Optional[TargetCapabilities] = None,
) -> None:
"""
Initialize the Azure Blob Storage target.
Expand All @@ -68,6 +85,8 @@ def __init__(
blob_content_type (SupportedContentType): The content type for blobs.
Defaults to PLAIN_TEXT.
max_requests_per_minute (int, Optional): Maximum number of requests per minute.
custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for
this target instance. Defaults to None.
"""
self._blob_content_type: str = blob_content_type.value

Expand All @@ -78,7 +97,11 @@ def __init__(
self._sas_token: Optional[str] = sas_token
self._client_async: Optional[AsyncContainerClient] = None

super().__init__(endpoint=self._container_url, max_requests_per_minute=max_requests_per_minute)
super().__init__(
endpoint=self._container_url,
max_requests_per_minute=max_requests_per_minute,
custom_capabilities=custom_capabilities,
)

def _build_identifier(self) -> ComponentIdentifier:
"""
Expand Down Expand Up @@ -196,18 +219,3 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
)

return [response]

def _validate_request(self, *, message: Message) -> None:
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received {n_pieces} pieces")

piece_type = message.message_pieces[0].converted_value_data_type
if piece_type not in ["text", "url"]:
raise ValueError(f"This target only supports text and url prompt input. Received: {piece_type}.")

request = message.message_pieces[0]
messages = self._memory.get_message_pieces(conversation_id=request.conversation_id)

if len(messages) > 0:
raise ValueError("This target only supports a single turn conversation.")
21 changes: 11 additions & 10 deletions pyrit/prompt_target/azure_ml_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p

logger = logging.getLogger(__name__)
Expand All @@ -40,6 +41,10 @@ class AzureMLChatTarget(PromptChatTarget):
endpoint_uri_environment_variable: str = "AZURE_ML_MANAGED_ENDPOINT"
api_key_environment_variable: str = "AZURE_ML_KEY"

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(
supports_multi_message_pieces=True, supports_editable_history=True
)

def __init__(
self,
*,
Expand All @@ -52,6 +57,7 @@ def __init__(
top_p: float = 1.0,
repetition_penalty: float = 1.0,
max_requests_per_minute: Optional[int] = None,
custom_capabilities: Optional[TargetCapabilities] = None,
**param_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -89,7 +95,11 @@ def __init__(
env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint
)
PromptChatTarget.__init__(
self, max_requests_per_minute=max_requests_per_minute, endpoint=endpoint_value, model_name=model_name
self,
max_requests_per_minute=max_requests_per_minute,
endpoint=endpoint_value,
model_name=model_name,
custom_capabilities=custom_capabilities,
)

self._initialize_vars(endpoint=endpoint, api_key=api_key)
Expand Down Expand Up @@ -272,12 +282,3 @@ def _get_headers(self) -> dict[str, str]:

def _validate_request(self, *, message: Message) -> None:
pass

def is_json_response_supported(self) -> bool:
"""
Check if the target supports JSON as a response format.

Returns:
bool: True if JSON response is supported, False otherwise.
"""
return False
22 changes: 7 additions & 15 deletions pyrit/prompt_target/common/prompt_chat_target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
from typing import Optional

from pyrit.identifiers import ComponentIdentifier
Expand All @@ -22,7 +21,9 @@ class PromptChatTarget(PromptTarget):
Realtime chat targets or OpenAI completions are NOT PromptChatTargets. You don't send the conversation history.
"""

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=True)
_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(
supports_multi_turn=True, supports_multi_message_pieces=True
)

def __init__(
self,
Expand All @@ -31,7 +32,7 @@ def __init__(
endpoint: str = "",
model_name: str = "",
underlying_model: Optional[str] = None,
capabilities: Optional[TargetCapabilities] = None,
custom_capabilities: Optional[TargetCapabilities] = None,
) -> None:
"""
Initialize the PromptChatTarget.
Expand All @@ -43,15 +44,15 @@ def __init__(
underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for
identification purposes. This is useful when the deployment name in Azure differs
from the actual model. Defaults to None.
capabilities (TargetCapabilities, Optional): Override the default capabilities for
custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for
this target instance. If None, uses the class-level defaults. Defaults to None.
"""
super().__init__(
max_requests_per_minute=max_requests_per_minute,
endpoint=endpoint,
model_name=model_name,
underlying_model=underlying_model,
capabilities=capabilities,
custom_capabilities=custom_capabilities,
)

def set_system_prompt(
Expand Down Expand Up @@ -85,15 +86,6 @@ def set_system_prompt(
).to_message()
)

@abc.abstractmethod
def is_json_response_supported(self) -> bool:
"""
Abstract method to determine if JSON response format is supported by the target.

Returns:
bool: True if JSON response is supported, False otherwise.
"""

def is_response_format_json(self, message_piece: MessagePiece) -> bool:
"""
Check if the response format is JSON and ensure the target supports it.
Expand Down Expand Up @@ -127,7 +119,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp
"""
config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata)

if config.enabled and not self.is_json_response_supported():
if config.enabled and not self.capabilities.supports_json_output:
target_name = self.get_identifier().class_name
raise ValueError(f"This target {target_name} does not support JSON response format.")

Expand Down
62 changes: 42 additions & 20 deletions pyrit/prompt_target/common/prompt_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@ class PromptTarget(Identifiable):

_memory: MemoryInterface

#: A list of PromptConverters that are supported by the prompt target.
#: An empty list implies that the prompt target supports all converters.
# A list of PromptConverters that are supported by the prompt target.
# An empty list implies that the prompt target supports all converters.
supported_converters: list[Any]

_identifier: Optional[ComponentIdentifier] = None

# Class-level default capabilities for this target type.
#
# Subclasses **should** override this when their capabilities differ from the base
# defaults (e.g., to declare multi-turn support or non-text modalities).
# Overriding is *optional* — if a subclass does not define ``_DEFAULT_CAPABILITIES``,
# it inherits the base-class default (text-only, single-turn, no JSON response).
#
# Per-instance overrides are also possible via the ``custom_capabilities``
# constructor parameter, which takes precedence over the class-level value.
_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities()

def __init__(
Expand All @@ -38,7 +47,7 @@ def __init__(
endpoint: str = "",
model_name: str = "",
underlying_model: Optional[str] = None,
capabilities: Optional[TargetCapabilities] = None,
custom_capabilities: Optional[TargetCapabilities] = None,
) -> None:
"""
Initialize the PromptTarget.
Expand All @@ -52,7 +61,7 @@ def __init__(
identification purposes. This is useful when the deployment name in Azure differs
from the actual model. If not provided, `model_name` will be used for the identifier.
Defaults to None.
capabilities (TargetCapabilities, Optional): Override the default capabilities for
custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could have a mapping of known default capabilities here, or potentially retrieved from a list in target_capabilites class.

if underlying_model == "gpt-5.1":
  _default = X
elif ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also talked about a method to discover these which I think will be useful. But the defaults we know could go a long way early on :)

this target instance. Useful for targets whose capabilities depend on deployment
configuration (e.g., Playwright, HTTP). If None, uses the class-level
``_DEFAULT_CAPABILITIES``. Defaults to None.
Expand All @@ -63,7 +72,9 @@ def __init__(
self._endpoint = endpoint
self._model_name = model_name
self._underlying_model = underlying_model
self._capabilities = capabilities if capabilities is not None else type(self)._DEFAULT_CAPABILITIES
self._capabilities = (
custom_capabilities if custom_capabilities is not None else type(self)._DEFAULT_CAPABILITIES
)

if self._verbose:
logging.basicConfig(level=logging.INFO)
Expand All @@ -78,14 +89,38 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
but some (like response target with tool calls) may return multiple messages.
"""

@abc.abstractmethod
def _validate_request(self, *, message: Message) -> None:
"""
Validate the provided message.

Args:
message: The message to validate.

Raises:
ValueError: if the target does not support the provided message pieces or if the message
violates any constraints based on the target's capabilities. This includes checks
for the number of message pieces, supported data types, and multi-turn conversation support.

"""
n_pieces = len(message.message_pieces)
if not self.capabilities.supports_multi_message_pieces and n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

for piece in message.message_pieces:
piece_type = piece.converted_value_data_type
supported_types_flat = {t for combo in self.capabilities.input_modalities for t in combo}
if piece_type not in supported_types_flat:
supported_types = ", ".join(sorted(supported_types_flat))
raise ValueError(
f"This target supports only the following data types: {supported_types}. Received: {piece_type}."
)

if not self.capabilities.supports_multi_turn:
request = message.message_pieces[0]
messages = self._memory.get_message_pieces(conversation_id=request.conversation_id)

if len(messages) > 0:
raise ValueError("This target only supports a single turn conversation.")

def set_model_name(self, *, model_name: str) -> None:
"""
Expand Down Expand Up @@ -133,7 +168,7 @@ def _create_identifier(
"endpoint": self._endpoint,
"model_name": model_name,
"max_requests_per_minute": self._max_requests_per_minute,
"supports_multi_turn": self.supports_multi_turn,
"supports_multi_turn": self.capabilities.supports_multi_turn,
}
if params:
all_params.update(params)
Expand All @@ -155,19 +190,6 @@ def capabilities(self) -> TargetCapabilities:
"""
return self._capabilities

@property
def supports_multi_turn(self) -> bool:
"""
Whether this target supports multi-turn conversations.

Convenience property that delegates to ``self.capabilities.supports_multi_turn``.

Returns:
bool: False by default. Subclasses declare multi-turn support by setting
``_DEFAULT_CAPABILITIES`` or passing ``capabilities`` to the constructor.
"""
return self._capabilities.supports_multi_turn

def _build_identifier(self) -> ComponentIdentifier:
"""
Build the identifier for this target.
Expand Down
Loading
Loading