Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyrit/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from pyrit.auth.authenticator import Authenticator
from pyrit.auth.azure_auth import (
AsyncTokenProviderCredential,
AzureAuth,
TokenProviderCredential,
ensure_async_token_provider,
get_azure_async_token_provider,
get_azure_openai_auth,
get_azure_token_provider,
Expand All @@ -19,12 +21,14 @@
from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator

__all__ = [
"AsyncTokenProviderCredential",
"Authenticator",
"AzureAuth",
"AzureStorageAuth",
"CopilotAuthenticator",
"ManualCopilotAuthenticator",
"TokenProviderCredential",
"ensure_async_token_provider",
"get_azure_token_provider",
"get_azure_async_token_provider",
"get_default_azure_scope",
Expand Down
100 changes: 99 additions & 1 deletion pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import inspect
import logging
import time
from typing import TYPE_CHECKING, Any, Union, cast
Expand All @@ -22,7 +23,7 @@
)

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable

import azure.cognitiveservices.speech as speechsdk

Expand Down Expand Up @@ -66,6 +67,103 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return AccessToken(str(token), expires_on)


class AsyncTokenProviderCredential:
"""
Async wrapper to convert a token provider callable into an Azure AsyncTokenCredential.

This class bridges the gap between token provider functions (sync or async) and Azure SDK
async clients that require an AsyncTokenCredential object (with async def get_token).
"""

def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None:
"""
Initialize AsyncTokenProviderCredential.

Args:
token_provider: A callable that returns a token string (sync) or an awaitable that
returns a token string (async). Both are supported transparently.
"""
self._token_provider = token_provider

async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""
Get an access token asynchronously.

Args:
scopes: Token scopes (ignored as the scope is already configured in the token provider).
kwargs: Additional arguments (ignored).

Returns:
AccessToken: The access token with expiration time.
"""
result = self._token_provider()
if inspect.isawaitable(result):
token = await result
else:
token = result
expires_on = int(time.time()) + 3600
return AccessToken(str(token), expires_on)

async def close(self) -> None:
"""No-op close for protocol compliance. The callable provider does not hold resources."""

async def __aenter__(self) -> AsyncTokenProviderCredential:
"""
Enter the async context manager.

Returns:
AsyncTokenProviderCredential: This credential instance.
"""
return self

async def __aexit__(self, *args: Any) -> None:
"""Exit the async context manager."""
await self.close()


def ensure_async_token_provider(
api_key: str | Callable[[], str | Awaitable[str]] | None,
) -> str | Callable[[], Awaitable[str]] | None:
"""
Ensure the api_key is either a string or an async callable.

If a synchronous callable token provider is provided, it's automatically wrapped
in an async function to make it compatible with async Azure SDK clients.

Args:
api_key: Either a string API key or a callable that returns a token (sync or async).

Returns:
Either a string API key or an async callable that returns a token.
"""
if api_key is None or isinstance(api_key, str) or not callable(api_key):
return api_key

# Check if the callable is already async
if inspect.iscoroutinefunction(api_key):
return api_key

# Wrap synchronous token provider in async function
logger.debug(
"Detected synchronous token provider."
" Automatically wrapping in async function for compatibility with async client."
)

async def async_token_provider() -> str:
"""
Async wrapper for synchronous token provider.

Returns:
str: The token string from the synchronous provider.
"""
result = api_key()
if inspect.isawaitable(result):
return await result
return result

return async_token_provider


class AzureAuth(Authenticator):
"""
Azure CLI Authentication.
Expand Down
45 changes: 2 additions & 43 deletions pyrit/prompt_target/openai/openai_target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import re
Expand All @@ -23,7 +22,7 @@
AuthenticationError,
)

from pyrit.auth import get_azure_openai_auth
from pyrit.auth import ensure_async_token_provider, get_azure_openai_auth
from pyrit.common import default_values
from pyrit.exceptions.exception_classes import (
RateLimitException,
Expand All @@ -41,46 +40,6 @@
logger = logging.getLogger(__name__)


def _ensure_async_token_provider(
api_key: Optional[str | Callable[[], str | Awaitable[str]]],
) -> Optional[str | Callable[[], Awaitable[str]]]:
"""
Ensure the api_key is either a string or an async callable.

If a synchronous callable token provider is provided, it's automatically wrapped
in an async function to make it compatible with AsyncOpenAI.

Args:
api_key: Either a string API key or a callable that returns a token (sync or async).

Returns:
Either a string API key or an async callable that returns a token.
"""
if api_key is None or isinstance(api_key, str) or not callable(api_key):
return api_key

# Check if the callable is already async
if asyncio.iscoroutinefunction(api_key):
return api_key

# Wrap synchronous token provider in async function
logger.info(
"Detected synchronous token provider."
" Automatically wrapping in async function for compatibility with AsyncOpenAI."
)

async def async_token_provider() -> str:
"""
Async wrapper for synchronous token provider.

Returns:
str: The token string from the synchronous provider.
"""
return api_key() # type: ignore[return-value]

return async_token_provider


class OpenAITarget(PromptTarget):
"""
Abstract base class for OpenAI-based prompt targets.
Expand Down Expand Up @@ -198,7 +157,7 @@ def __init__(
)

# Ensure api_key is async-compatible (wrap sync token providers if needed)
self._api_key = _ensure_async_token_provider(resolved_api_key)
self._api_key = ensure_async_token_provider(resolved_api_key)

self._initialize_openai_client()

Expand Down
55 changes: 22 additions & 33 deletions pyrit/score/float_scale/azure_content_filter_scorer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import inspect
from collections.abc import Callable
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Optional

from azure.ai.contentsafety import ContentSafetyClient
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.ai.contentsafety.models import (
AnalyzeImageOptions,
AnalyzeImageResult,
Expand All @@ -18,7 +17,7 @@
)
from azure.core.credentials import AzureKeyCredential

from pyrit.auth import TokenProviderCredential, get_azure_token_provider
from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider
from pyrit.common import default_values
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import (
Expand All @@ -38,6 +37,8 @@
from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles
from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics

logger = logging.getLogger(__name__)


class AzureContentFilterScorer(FloatScaleScorer):
"""
Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(
self,
*,
endpoint: Optional[str | None] = None,
api_key: Optional[str | Callable[[], str] | None] = None,
api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None,
harm_categories: Optional[list[TextCategory]] = None,
validator: Optional[ScorerPromptValidator] = None,
) -> None:
Expand All @@ -104,13 +105,12 @@ def __init__(
Args:
endpoint (Optional[str | None]): The endpoint URL for the Azure Content Safety service.
Defaults to the `ENDPOINT_URI_ENVIRONMENT_VARIABLE` environment variable.
api_key (Optional[str | Callable[[], str] | None]):
api_key (Optional[str | Callable[[], str | Awaitable[str]] | None]):
The API key for accessing the Azure Content Safety service,
or a synchronous callable that returns an access token. Async token providers
are not supported. If not provided (via parameter
or environment variable), Entra ID authentication is used automatically.
You can also explicitly pass a token provider from pyrit.auth
(e.g., get_azure_token_provider('https://cognitiveservices.azure.com/.default')).
or a callable that returns an access token. Both synchronous and asynchronous
token providers are supported. Sync providers are automatically wrapped for
async compatibility. If not provided (via parameter or environment variable),
Entra ID authentication is used automatically.
Defaults to the `API_KEY_ENVIRONMENT_VARIABLE` environment variable.
harm_categories (Optional[list[TextCategory]]): The harm categories you want to query for as
defined in azure.ai.contentsafety.models.TextCategory. If not provided, defaults to all categories.
Expand All @@ -129,36 +129,25 @@ def __init__(
)

# API key: use passed value, env var, or fall back to Entra ID for Azure endpoints
resolved_api_key: str | Callable[[], str]
resolved_api_key: str | Callable[[], str | Awaitable[str]]
if api_key is not None and callable(api_key):
if asyncio.iscoroutinefunction(api_key):
raise ValueError(
"Async token providers are not supported by AzureContentFilterScorer. "
"Use a synchronous token provider (e.g., get_azure_token_provider) instead."
)
# Guard against sync callables that return coroutines/awaitables (e.g., lambda: async_fn())
test_result = api_key()
if inspect.isawaitable(test_result):
if hasattr(test_result, "close"):
test_result.close() # prevent "coroutine was never awaited" warning
raise ValueError(
"The provided token provider returns a coroutine/awaitable, which is not supported "
"by AzureContentFilterScorer. Use a synchronous token provider instead."
)
resolved_api_key = api_key
else:
api_key_value = default_values.get_non_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
resolved_api_key = api_key_value or get_azure_token_provider("https://cognitiveservices.azure.com/.default")
resolved_api_key = api_key_value or get_azure_async_token_provider(
"https://cognitiveservices.azure.com/.default"
)

self._api_key = resolved_api_key
# Ensure api_key is async-compatible (wrap sync token providers if needed)
self._api_key = ensure_async_token_provider(resolved_api_key)

# Create ContentSafetyClient with appropriate credential
if self._endpoint is not None:
if callable(self._api_key):
# Token provider - create a TokenCredential wrapper
credential = TokenProviderCredential(self._api_key)
# Token provider - create an AsyncTokenCredential wrapper
credential = AsyncTokenProviderCredential(self._api_key)
self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential)
else:
# String API key
Expand Down Expand Up @@ -291,7 +280,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
categories=self._category_values,
output_type="EightSeverityLevels",
)
text_result = self._azure_cf_client.analyze_text(text_request_options)
text_result = await self._azure_cf_client.analyze_text(text_request_options)
filter_results.append(text_result)

elif message_piece.converted_value_data_type == "image_path":
Expand All @@ -301,7 +290,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
image_request_options = AnalyzeImageOptions(
image=image_data, categories=self._category_values, output_type="FourSeverityLevels"
)
image_result = self._azure_cf_client.analyze_image(image_request_options)
image_result = await self._azure_cf_client.analyze_image(image_request_options)
filter_results.append(image_result)

# Collect all scores from all chunks/images
Expand Down
Loading
Loading