From b9f8edcacba73e1c504c98e9b6d4630ea2a1e4a2 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Fri, 13 Mar 2026 16:41:32 -0400 Subject: [PATCH 1/6] Adding support for async in azure content filter scorer --- pyrit/auth/__init__.py | 2 + pyrit/auth/azure_auth.py | 45 +++++++++- .../azure_content_filter_scorer.py | 89 ++++++++++++------- tests/integration/mocks.py | 4 +- .../test_azure_content_filter_integration.py | 19 ++-- tests/unit/score/test_azure_content_filter.py | 34 ++++--- 6 files changed, 132 insertions(+), 61 deletions(-) diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 4074809e51..02cd90b1dd 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -7,6 +7,7 @@ from pyrit.auth.authenticator import Authenticator from pyrit.auth.azure_auth import ( + AsyncTokenProviderCredential, AzureAuth, TokenProviderCredential, get_azure_async_token_provider, @@ -19,6 +20,7 @@ from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator __all__ = [ + "AsyncTokenProviderCredential", "Authenticator", "AzureAuth", "AzureStorageAuth", diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index b606189636..ce9aed3028 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -3,6 +3,7 @@ from __future__ import annotations +import inspect import logging import time from typing import TYPE_CHECKING, Any, Union, cast @@ -23,7 +24,7 @@ ) if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable import azure.cognitiveservices.speech as speechsdk @@ -67,6 +68,48 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: return AccessToken(str(token), expires_on) +class AsyncTokenProviderCredential: + """ + Async wrapper to convert a token provider callable into an Azure AsyncTokenCredential. + + This class bridges the gap between token provider functions (sync or async) and Azure SDK + async clients that require an AsyncTokenCredential object (with async def get_token). + """ + + def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None: + """ + Initialize AsyncTokenProviderCredential. + + Args: + token_provider: A callable that returns a token string (sync) or an awaitable that + returns a token string (async). Both are supported transparently. + """ + self._token_provider = token_provider + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """ + Get an access token asynchronously. + + Args: + scopes: Token scopes (ignored as the scope is already configured in the token provider). + kwargs: Additional arguments (ignored). + + Returns: + AccessToken: The access token with expiration time. + """ + result = self._token_provider() + if inspect.isawaitable(result): + token = await result + else: + token = result + expires_on = int(time.time()) + 3600 + return AccessToken(str(token), expires_on) + + async def close(self) -> None: + """No-op close for protocol compliance. The callable provider does not hold resources.""" + pass + + class AzureAuth(Authenticator): """ Azure CLI Authentication. diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 6b587d4f30..1ca6781919 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -1,13 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import asyncio import base64 import inspect -from collections.abc import Callable +import logging +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Optional -from azure.ai.contentsafety import ContentSafetyClient +from azure.ai.contentsafety.aio import ContentSafetyClient from azure.ai.contentsafety.models import ( AnalyzeImageOptions, AnalyzeImageResult, @@ -18,7 +18,7 @@ ) from azure.core.credentials import AzureKeyCredential -from pyrit.auth import TokenProviderCredential, get_azure_token_provider +from pyrit.auth import AsyncTokenProviderCredential, get_azure_async_token_provider from pyrit.common import default_values from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( @@ -38,6 +38,43 @@ from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics +logger = logging.getLogger(__name__) + + +def _ensure_async_token_provider( + api_key: str | Callable[[], str | Awaitable[str]] | None, +) -> str | Callable[[], Awaitable[str]] | None: + """ + Ensure the api_key is either a string or an async callable. + + If a synchronous callable token provider is provided, it's automatically wrapped + in an async function to make it compatible with the async ContentSafetyClient. + + Args: + api_key: Either a string API key or a callable that returns a token (sync or async). + + Returns: + Either a string API key or an async callable that returns a token. + """ + if api_key is None or isinstance(api_key, str) or not callable(api_key): + return api_key + + # Check if the callable is already async + if inspect.iscoroutinefunction(api_key): + return api_key + + # Wrap synchronous token provider in async function + logger.info( + "Detected synchronous token provider." + " Automatically wrapping in async function for compatibility with async ContentSafetyClient." + ) + + async def async_token_provider() -> str: + """Async wrapper for synchronous token provider.""" + return api_key() # type: ignore[return-value] + + return async_token_provider + class AzureContentFilterScorer(FloatScaleScorer): """ @@ -94,7 +131,7 @@ def __init__( self, *, endpoint: Optional[str | None] = None, - api_key: Optional[str | Callable[[], str] | None] = None, + api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None, harm_categories: Optional[list[TextCategory]] = None, validator: Optional[ScorerPromptValidator] = None, ) -> None: @@ -104,13 +141,12 @@ def __init__( Args: endpoint (Optional[str | None]): The endpoint URL for the Azure Content Safety service. Defaults to the `ENDPOINT_URI_ENVIRONMENT_VARIABLE` environment variable. - api_key (Optional[str | Callable[[], str] | None]): + api_key (Optional[str | Callable[[], str | Awaitable[str]] | None]): The API key for accessing the Azure Content Safety service, - or a synchronous callable that returns an access token. Async token providers - are not supported. If not provided (via parameter - or environment variable), Entra ID authentication is used automatically. - You can also explicitly pass a token provider from pyrit.auth - (e.g., get_azure_token_provider('https://cognitiveservices.azure.com/.default')). + or a callable that returns an access token. Both synchronous and asynchronous + token providers are supported. Sync providers are automatically wrapped for + async compatibility. If not provided (via parameter or environment variable), + Entra ID authentication is used automatically. Defaults to the `API_KEY_ENVIRONMENT_VARIABLE` environment variable. harm_categories (Optional[list[TextCategory]]): The harm categories you want to query for as defined in azure.ai.contentsafety.models.TextCategory. If not provided, defaults to all categories. @@ -129,36 +165,25 @@ def __init__( ) # API key: use passed value, env var, or fall back to Entra ID for Azure endpoints - resolved_api_key: str | Callable[[], str] + resolved_api_key: str | Callable[[], str | Awaitable[str]] if api_key is not None and callable(api_key): - if asyncio.iscoroutinefunction(api_key): - raise ValueError( - "Async token providers are not supported by AzureContentFilterScorer. " - "Use a synchronous token provider (e.g., get_azure_token_provider) instead." - ) - # Guard against sync callables that return coroutines/awaitables (e.g., lambda: async_fn()) - test_result = api_key() - if inspect.isawaitable(test_result): - if hasattr(test_result, "close"): - test_result.close() # prevent "coroutine was never awaited" warning - raise ValueError( - "The provided token provider returns a coroutine/awaitable, which is not supported " - "by AzureContentFilterScorer. Use a synchronous token provider instead." - ) resolved_api_key = api_key else: api_key_value = default_values.get_non_required_value( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key ) - resolved_api_key = api_key_value or get_azure_token_provider("https://cognitiveservices.azure.com/.default") + resolved_api_key = api_key_value or get_azure_async_token_provider( + "https://cognitiveservices.azure.com/.default" + ) - self._api_key = resolved_api_key + # Ensure api_key is async-compatible (wrap sync token providers if needed) + self._api_key = _ensure_async_token_provider(resolved_api_key) # Create ContentSafetyClient with appropriate credential if self._endpoint is not None: if callable(self._api_key): - # Token provider - create a TokenCredential wrapper - credential = TokenProviderCredential(self._api_key) + # Token provider - create an AsyncTokenCredential wrapper + credential = AsyncTokenProviderCredential(self._api_key) self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key @@ -291,7 +316,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op categories=self._category_values, output_type="EightSeverityLevels", ) - text_result = self._azure_cf_client.analyze_text(text_request_options) + text_result = await self._azure_cf_client.analyze_text(text_request_options) filter_results.append(text_result) elif message_piece.converted_value_data_type == "image_path": @@ -301,7 +326,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op image_request_options = AnalyzeImageOptions( image=image_data, categories=self._category_values, output_type="FourSeverityLevels" ) - image_result = self._azure_cf_client.analyze_image(image_request_options) + image_result = await self._azure_cf_client.analyze_image(image_request_options) filter_results.append(image_result) # Collect all scores from all chunks/images diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index fedeb95e12..1c997f3326 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -6,7 +6,7 @@ from sqlalchemy import inspect -from pyrit.identifiers import AttackIdentifier +from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, SQLiteMemory from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -49,7 +49,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[AttackIdentifier] = None, + attack_identifier: Optional[ComponentIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: self.system_prompt = system_prompt diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py index 9f7fdef20b..53d40ff320 100644 --- a/tests/integration/score/test_azure_content_filter_integration.py +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -12,6 +12,11 @@ from pyrit.memory import CentralMemory, MemoryInterface from pyrit.score import AzureContentFilterScorer +pytestmark = pytest.mark.skipif( + not os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT"), + reason="AZURE_CONTENT_SAFETY_API_ENDPOINT not configured", +) + @pytest.fixture def memory() -> Generator[MemoryInterface, None, None]: @@ -27,13 +32,6 @@ async def test_azure_content_filter_scorer_image_integration(memory) -> None: environment variables to be set. Uses a sample image from the assets folder. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): - # Verify required environment variables are set - api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") - endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT") - - if not api_key or not endpoint: - pytest.skip("Azure Content Safety credentials not configured") - scorer = AzureContentFilterScorer() image_path = HOME_PATH / "assets" / "architecture_components.png" @@ -62,13 +60,6 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory This verifies that the chunking and aggregation logic works correctly with the real API. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): - # Verify required environment variables are set - api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") - endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT") - - if not api_key or not endpoint: - pytest.skip("Azure Content Safety credentials not configured") - scorer = AzureContentFilterScorer() # This should be greater than the rate limit diff --git a/tests/unit/score/test_azure_content_filter.py b/tests/unit/score/test_azure_content_filter.py index 27c0cc298a..bd5d97b26e 100644 --- a/tests/unit/score/test_azure_content_filter.py +++ b/tests/unit/score/test_azure_content_filter.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. +import inspect import os from unittest.mock import AsyncMock, MagicMock, patch @@ -55,7 +56,7 @@ async def test_score_async_unsupported_data_type_returns_empty_list( @pytest.mark.asyncio async def test_score_piece_async_text(patch_central_database, text_message_piece: MessagePiece): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client scores = await scorer._score_piece_async(text_message_piece) @@ -72,7 +73,7 @@ async def test_score_piece_async_text(patch_central_database, text_message_piece @pytest.mark.asyncio async def test_score_piece_async_image(patch_central_database, image_message_piece: MessagePiece): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_image.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client # Patch _get_base64_image_data to avoid actual file IO @@ -102,25 +103,34 @@ def test_explicit_category(): assert len(scorer._harm_categories) == 1 -def test_async_callable_api_key_raises(): +def test_async_callable_api_key_accepted(): async def async_provider(): return "token" - with pytest.raises(ValueError, match="Async token providers are not supported"): - AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + scorer = AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + # Async callable should be passed through as-is + assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) -def test_sync_callable_returning_coroutine_raises(): +def test_sync_callable_returning_coroutine_accepted(): async def async_fn(): return "token" - with pytest.raises(ValueError, match="returns a coroutine/awaitable"): - AzureContentFilterScorer(api_key=lambda: async_fn(), endpoint="bar") + sync_lambda = lambda: async_fn() # noqa: E731 + # Confirm the lambda itself is NOT a coroutine function (it's sync) + assert not inspect.iscoroutinefunction(sync_lambda) + + scorer = AzureContentFilterScorer(api_key=sync_lambda, endpoint="bar") + # After init, the sync callable should be wrapped in an async function + assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) def test_sync_callable_api_key_accepted(): scorer = AzureContentFilterScorer(api_key=lambda: "token", endpoint="bar") assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) @pytest.mark.asyncio @@ -129,7 +139,7 @@ async def test_azure_content_filter_scorer_adds_to_memory(): with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -143,7 +153,7 @@ async def test_azure_content_filter_scorer_adds_to_memory(): async def test_azure_content_filter_scorer_score(patch_central_database): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -181,7 +191,7 @@ async def test_azure_content_filter_scorer_chunks_long_text(patch_central_databa with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() # Mock returns for two chunks mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -205,7 +215,7 @@ async def test_azure_content_filter_scorer_accepts_short_text(patch_central_data with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client From f71b015d90a696a764677a7aca82779142c88f5c Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Fri, 13 Mar 2026 19:23:41 -0400 Subject: [PATCH 2/6] Fix pre-commit: add Returns docstring to __aenter__ and ruff format fix --- pyrit/auth/azure_auth.py | 14 +++++++++++++- .../float_scale/azure_content_filter_scorer.py | 7 ++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index ce9aed3028..2cb471c6c7 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -107,7 +107,19 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: async def close(self) -> None: """No-op close for protocol compliance. The callable provider does not hold resources.""" - pass + + async def __aenter__(self) -> AsyncTokenProviderCredential: + """ + Enter the async context manager. + + Returns: + AsyncTokenProviderCredential: This credential instance. + """ + return self + + async def __aexit__(self, *args: Any) -> None: + """Exit the async context manager.""" + await self.close() class AzureAuth(Authenticator): diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 1ca6781919..303577416e 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -70,7 +70,12 @@ def _ensure_async_token_provider( ) async def async_token_provider() -> str: - """Async wrapper for synchronous token provider.""" + """ + Async wrapper for synchronous token provider. + + Returns: + str: The token string from the synchronous provider. + """ return api_key() # type: ignore[return-value] return async_token_provider From 0922f5554832861b5a7e3457ad66b1d6d841f4d7 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Mon, 16 Mar 2026 11:25:33 -0400 Subject: [PATCH 3/6] Address code review: isawaitable handling, token verification tests, docstring fix - Add inspect.isawaitable() check in _ensure_async_token_provider wrapper to handle sync callables that return coroutines (e.g. lambda: async_fn()) - Add _returns_token tests that await scorer._api_key() and verify the actual token value for all three provider types (async, sync-returning- coroutine, sync) - Update integration test docstring to match skip condition: endpoint is required, API key is optional (Entra ID is default auth) - Remove unnecessary type: ignore comments flagged by mypy Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../azure_content_filter_scorer.py | 5 +++- .../test_azure_content_filter_integration.py | 6 ++-- tests/unit/score/test_azure_content_filter.py | 28 +++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 303577416e..15f35ab3c9 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -76,7 +76,10 @@ async def async_token_provider() -> str: Returns: str: The token string from the synchronous provider. """ - return api_key() # type: ignore[return-value] + result = api_key() + if inspect.isawaitable(result): + return await result + return result return async_token_provider diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py index 53d40ff320..b8b09e6fce 100644 --- a/tests/integration/score/test_azure_content_filter_integration.py +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -28,8 +28,10 @@ async def test_azure_content_filter_scorer_image_integration(memory) -> None: """ Integration test for Azure Content Filter Scorer with image input. - This test requires AZURE_CONTENT_SAFETY_API_KEY and AZURE_CONTENT_SAFETY_API_ENDPOINT - environment variables to be set. Uses a sample image from the assets folder. + This test requires AZURE_CONTENT_SAFETY_API_ENDPOINT to be set. + Authentication uses Entra ID by default (via `az login`). Alternatively, + set AZURE_CONTENT_SAFETY_API_KEY for API key auth. + Uses a sample image from the assets folder. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer() diff --git a/tests/unit/score/test_azure_content_filter.py b/tests/unit/score/test_azure_content_filter.py index bd5d97b26e..9cd391398d 100644 --- a/tests/unit/score/test_azure_content_filter.py +++ b/tests/unit/score/test_azure_content_filter.py @@ -113,6 +113,16 @@ async def async_provider(): assert inspect.iscoroutinefunction(scorer._api_key) +@pytest.mark.asyncio +async def test_async_callable_api_key_returns_token(): + async def async_provider(): + return "token" + + scorer = AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + result = await scorer._api_key() + assert result == "token" + + def test_sync_callable_returning_coroutine_accepted(): async def async_fn(): return "token" @@ -127,12 +137,30 @@ async def async_fn(): assert inspect.iscoroutinefunction(scorer._api_key) +@pytest.mark.asyncio +async def test_sync_callable_returning_coroutine_returns_token(): + async def async_fn(): + return "token" + + sync_lambda = lambda: async_fn() # noqa: E731 + scorer = AzureContentFilterScorer(api_key=sync_lambda, endpoint="bar") + result = await scorer._api_key() + assert result == "token" + + def test_sync_callable_api_key_accepted(): scorer = AzureContentFilterScorer(api_key=lambda: "token", endpoint="bar") assert callable(scorer._api_key) assert inspect.iscoroutinefunction(scorer._api_key) +@pytest.mark.asyncio +async def test_sync_callable_api_key_returns_token(): + scorer = AzureContentFilterScorer(api_key=lambda: "token", endpoint="bar") + result = await scorer._api_key() + assert result == "token" + + @pytest.mark.asyncio async def test_azure_content_filter_scorer_adds_to_memory(): memory = MagicMock(MemoryInterface) From ea6f54876e8df570e99ad0d4ece699b1a0c85614 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Mon, 16 Mar 2026 11:47:18 -0400 Subject: [PATCH 4/6] Downgrade sync token provider log from INFO to DEBUG The wrapping of sync providers is a normal init-time path, not worth INFO-level noise. DEBUG keeps the diagnostic available when needed. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/score/float_scale/azure_content_filter_scorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 15f35ab3c9..42109a3680 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -64,7 +64,7 @@ def _ensure_async_token_provider( return api_key # Wrap synchronous token provider in async function - logger.info( + logger.debug( "Detected synchronous token provider." " Automatically wrapping in async function for compatibility with async ContentSafetyClient." ) From 477132a0f8c1505ecd8ad56ad93d3fbda6a7619b Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Mon, 16 Mar 2026 15:33:46 -0400 Subject: [PATCH 5/6] Address review: remove skipif, add API key auth tests with assert - Remove module-level pytestmark skipif for AZURE_CONTENT_SAFETY_API_ENDPOINT - Add integration tests for image and text scoring with explicit API key auth - Use assert instead of pytest.skip so tests fail visibly when key is not set Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../test_azure_content_filter_integration.py | 64 +++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py index b8b09e6fce..b0b0d1f9cd 100644 --- a/tests/integration/score/test_azure_content_filter_integration.py +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -12,11 +12,6 @@ from pyrit.memory import CentralMemory, MemoryInterface from pyrit.score import AzureContentFilterScorer -pytestmark = pytest.mark.skipif( - not os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT"), - reason="AZURE_CONTENT_SAFETY_API_ENDPOINT not configured", -) - @pytest.fixture def memory() -> Generator[MemoryInterface, None, None]: @@ -84,3 +79,62 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory # Long benign text should still have low scores max_score = max(float(s.score_value) for s in scores) assert max_score > 0, "text should have > 0 score" + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_image_with_api_key(memory) -> None: + """ + Integration test for Azure Content Filter Scorer image scoring with explicit API key auth. + + This test requires AZURE_CONTENT_SAFETY_API_KEY to be set. + """ + api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") + assert api_key, "AZURE_CONTENT_SAFETY_API_KEY must be set to test API key auth" + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = AzureContentFilterScorer(api_key=api_key) + + image_path = HOME_PATH / "assets" / "architecture_components.png" + scores = await scorer.score_image_async(image_path=str(image_path)) + + assert len(scores) > 0, "Expected scores for harm categories" + + for score in scores: + assert score.score_type == "float_scale" + assert score.score_category is not None + assert 0.0 <= float(score.score_value) <= 1.0 + assert "azure_severity" in score.score_metadata + + max_score = max(float(s.score_value) for s in scores) + assert max_score < 0.5, "Architecture diagram should have low harm scores" + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_text_with_api_key(memory) -> None: + """ + Integration test for Azure Content Filter Scorer text scoring with explicit API key auth. + + This test requires AZURE_CONTENT_SAFETY_API_KEY to be set. + """ + api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") + assert api_key, "AZURE_CONTENT_SAFETY_API_KEY must be set to test API key auth" + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = AzureContentFilterScorer(api_key=api_key) + + base_text = "I hate you " * 10000 + + assert len(base_text) > AzureContentFilterScorer.MAX_TEXT_LENGTH + + scores = await scorer.score_text_async(text=base_text) + + assert len(scores) > 0, "Expected aggregated scores for harm categories" + + for score in scores: + assert score.score_type == "float_scale" + assert score.score_category is not None + assert 0.0 <= float(score.score_value) <= 1.0 + assert "azure_severity" in score.score_metadata + + max_score = max(float(s.score_value) for s in scores) + assert max_score > 0, "text should have > 0 score" From 0d024cf5c53eaae29ee4cd1386c8e6eb47faf349 Mon Sep 17 00:00:00 2001 From: Adrian Gavrila Date: Mon, 16 Mar 2026 17:51:59 -0400 Subject: [PATCH 6/6] Refactor: extract ensure_async_token_provider into pyrit.auth.azure_auth Move _ensure_async_token_provider from openai_target.py and azure_content_filter_scorer.py into a shared public helper in pyrit/auth/azure_auth.py. Uses the more robust implementation (inspect.iscoroutinefunction + isawaitable handling) and logs at DEBUG level. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auth/__init__.py | 2 + pyrit/auth/azure_auth.py | 43 +++++++++++++++++ pyrit/prompt_target/openai/openai_target.py | 45 +---------------- .../azure_content_filter_scorer.py | 48 +------------------ tests/unit/target/test_openai_target_auth.py | 15 +++--- .../target/test_token_provider_wrapping.py | 40 ++++++++-------- 6 files changed, 77 insertions(+), 116 deletions(-) diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 02cd90b1dd..3b17fcef61 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -10,6 +10,7 @@ AsyncTokenProviderCredential, AzureAuth, TokenProviderCredential, + ensure_async_token_provider, get_azure_async_token_provider, get_azure_openai_auth, get_azure_token_provider, @@ -27,6 +28,7 @@ "CopilotAuthenticator", "ManualCopilotAuthenticator", "TokenProviderCredential", + "ensure_async_token_provider", "get_azure_token_provider", "get_azure_async_token_provider", "get_default_azure_scope", diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 193edd72b0..4149749e45 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -121,6 +121,49 @@ async def __aexit__(self, *args: Any) -> None: await self.close() +def ensure_async_token_provider( + api_key: str | Callable[[], str | Awaitable[str]] | None, +) -> str | Callable[[], Awaitable[str]] | None: + """ + Ensure the api_key is either a string or an async callable. + + If a synchronous callable token provider is provided, it's automatically wrapped + in an async function to make it compatible with async Azure SDK clients. + + Args: + api_key: Either a string API key or a callable that returns a token (sync or async). + + Returns: + Either a string API key or an async callable that returns a token. + """ + if api_key is None or isinstance(api_key, str) or not callable(api_key): + return api_key + + # Check if the callable is already async + if inspect.iscoroutinefunction(api_key): + return api_key + + # Wrap synchronous token provider in async function + logger.debug( + "Detected synchronous token provider." + " Automatically wrapping in async function for compatibility with async client." + ) + + async def async_token_provider() -> str: + """ + Async wrapper for synchronous token provider. + + Returns: + str: The token string from the synchronous provider. + """ + result = api_key() + if inspect.isawaitable(result): + return await result + return result + + return async_token_provider + + class AzureAuth(Authenticator): """ Azure CLI Authentication. diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 0128991e3f..0549cd8f62 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import asyncio import json import logging import re @@ -23,7 +22,7 @@ AuthenticationError, ) -from pyrit.auth import get_azure_openai_auth +from pyrit.auth import ensure_async_token_provider, get_azure_openai_auth from pyrit.common import default_values from pyrit.exceptions.exception_classes import ( RateLimitException, @@ -41,46 +40,6 @@ logger = logging.getLogger(__name__) -def _ensure_async_token_provider( - api_key: Optional[str | Callable[[], str | Awaitable[str]]], -) -> Optional[str | Callable[[], Awaitable[str]]]: - """ - Ensure the api_key is either a string or an async callable. - - If a synchronous callable token provider is provided, it's automatically wrapped - in an async function to make it compatible with AsyncOpenAI. - - Args: - api_key: Either a string API key or a callable that returns a token (sync or async). - - Returns: - Either a string API key or an async callable that returns a token. - """ - if api_key is None or isinstance(api_key, str) or not callable(api_key): - return api_key - - # Check if the callable is already async - if asyncio.iscoroutinefunction(api_key): - return api_key - - # Wrap synchronous token provider in async function - logger.info( - "Detected synchronous token provider." - " Automatically wrapping in async function for compatibility with AsyncOpenAI." - ) - - async def async_token_provider() -> str: - """ - Async wrapper for synchronous token provider. - - Returns: - str: The token string from the synchronous provider. - """ - return api_key() # type: ignore[return-value] - - return async_token_provider - - class OpenAITarget(PromptTarget): """ Abstract base class for OpenAI-based prompt targets. @@ -198,7 +157,7 @@ def __init__( ) # Ensure api_key is async-compatible (wrap sync token providers if needed) - self._api_key = _ensure_async_token_provider(resolved_api_key) + self._api_key = ensure_async_token_provider(resolved_api_key) self._initialize_openai_client() diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 42109a3680..16aa3d75ab 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import base64 -import inspect import logging from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Optional @@ -18,7 +17,7 @@ ) from azure.core.credentials import AzureKeyCredential -from pyrit.auth import AsyncTokenProviderCredential, get_azure_async_token_provider +from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider from pyrit.common import default_values from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( @@ -41,49 +40,6 @@ logger = logging.getLogger(__name__) -def _ensure_async_token_provider( - api_key: str | Callable[[], str | Awaitable[str]] | None, -) -> str | Callable[[], Awaitable[str]] | None: - """ - Ensure the api_key is either a string or an async callable. - - If a synchronous callable token provider is provided, it's automatically wrapped - in an async function to make it compatible with the async ContentSafetyClient. - - Args: - api_key: Either a string API key or a callable that returns a token (sync or async). - - Returns: - Either a string API key or an async callable that returns a token. - """ - if api_key is None or isinstance(api_key, str) or not callable(api_key): - return api_key - - # Check if the callable is already async - if inspect.iscoroutinefunction(api_key): - return api_key - - # Wrap synchronous token provider in async function - logger.debug( - "Detected synchronous token provider." - " Automatically wrapping in async function for compatibility with async ContentSafetyClient." - ) - - async def async_token_provider() -> str: - """ - Async wrapper for synchronous token provider. - - Returns: - str: The token string from the synchronous provider. - """ - result = api_key() - if inspect.isawaitable(result): - return await result - return result - - return async_token_provider - - class AzureContentFilterScorer(FloatScaleScorer): """ A scorer that uses Azure Content Safety API to evaluate text and images for harmful content. @@ -185,7 +141,7 @@ def __init__( ) # Ensure api_key is async-compatible (wrap sync token providers if needed) - self._api_key = _ensure_async_token_provider(resolved_api_key) + self._api_key = ensure_async_token_provider(resolved_api_key) # Create ContentSafetyClient with appropriate credential if self._endpoint is not None: diff --git a/tests/unit/target/test_openai_target_auth.py b/tests/unit/target/test_openai_target_auth.py index 2045ae6e20..efe00a0caf 100644 --- a/tests/unit/target/test_openai_target_auth.py +++ b/tests/unit/target/test_openai_target_auth.py @@ -9,7 +9,8 @@ import pytest -from pyrit.prompt_target.openai.openai_target import OpenAITarget, _ensure_async_token_provider +from pyrit.auth import ensure_async_token_provider +from pyrit.prompt_target.openai.openai_target import OpenAITarget class _ConcreteOpenAITarget(OpenAITarget): @@ -126,30 +127,30 @@ def test_param_api_key_takes_precedence_over_env_var(self): class TestEnsureAsyncTokenProvider: - """Tests for the _ensure_async_token_provider helper function.""" + """Tests for the ensure_async_token_provider helper function.""" def test_none_returns_none(self): - assert _ensure_async_token_provider(None) is None + assert ensure_async_token_provider(None) is None def test_string_returns_string(self): - assert _ensure_async_token_provider("my-key") == "my-key" + assert ensure_async_token_provider("my-key") == "my-key" def test_async_callable_returned_as_is(self): async def provider() -> str: return "token" - result = _ensure_async_token_provider(provider) + result = ensure_async_token_provider(provider) assert result is provider def test_sync_callable_wrapped_to_async(self): def provider() -> str: return "sync-token" - result = _ensure_async_token_provider(provider) + result = ensure_async_token_provider(provider) assert asyncio.iscoroutinefunction(result) assert asyncio.run(result()) == "sync-token" def test_non_callable_non_string_returned_as_is(self): # Edge case: something that's not a string and not callable - result = _ensure_async_token_provider(42) # type: ignore[arg-type] + result = ensure_async_token_provider(42) # type: ignore[arg-type] assert result == 42 diff --git a/tests/unit/target/test_token_provider_wrapping.py b/tests/unit/target/test_token_provider_wrapping.py index 8ca874ce0d..9f5657fed2 100644 --- a/tests/unit/target/test_token_provider_wrapping.py +++ b/tests/unit/target/test_token_provider_wrapping.py @@ -6,7 +6,7 @@ import pytest -from pyrit.prompt_target.openai.openai_target import _ensure_async_token_provider +from pyrit.auth import ensure_async_token_provider class TestTokenProviderWrapping: @@ -15,13 +15,13 @@ class TestTokenProviderWrapping: def test_string_api_key_unchanged(self): """Test that string API keys are returned unchanged.""" api_key = "sk-test-key-12345" - result = _ensure_async_token_provider(api_key) + result = ensure_async_token_provider(api_key) assert result == api_key assert isinstance(result, str) def test_none_api_key_unchanged(self): """Test that None is returned unchanged.""" - result = _ensure_async_token_provider(None) + result = ensure_async_token_provider(None) assert result is None def test_async_token_provider_unchanged(self): @@ -30,7 +30,7 @@ def test_async_token_provider_unchanged(self): async def async_token_provider(): return "async-token" - result = _ensure_async_token_provider(async_token_provider) + result = ensure_async_token_provider(async_token_provider) assert result is async_token_provider assert asyncio.iscoroutinefunction(result) @@ -40,7 +40,7 @@ def test_sync_token_provider_wrapped(self): def sync_token_provider(): return "sync-token" - result = _ensure_async_token_provider(sync_token_provider) + result = ensure_async_token_provider(sync_token_provider) # Should return a different callable (the wrapper) assert result is not sync_token_provider @@ -54,7 +54,7 @@ async def test_wrapped_sync_provider_returns_correct_token(self): def sync_token_provider(): return "my-sync-token" - wrapped = _ensure_async_token_provider(sync_token_provider) + wrapped = ensure_async_token_provider(sync_token_provider) # Call the wrapped provider token = await wrapped() @@ -67,7 +67,7 @@ async def test_async_provider_returns_correct_token(self): async def async_token_provider(): return "my-async-token" - result = _ensure_async_token_provider(async_token_provider) + result = ensure_async_token_provider(async_token_provider) # Should be the same function assert result is async_token_provider @@ -86,7 +86,7 @@ def sync_token_provider(): call_count += 1 return f"token-{call_count}" - wrapped = _ensure_async_token_provider(sync_token_provider) + wrapped = ensure_async_token_provider(sync_token_provider) # Call multiple times token1 = await wrapped() @@ -97,15 +97,15 @@ def sync_token_provider(): assert call_count == 2 def test_sync_provider_wrapping_logs_info(self): - """Test that wrapping a sync provider logs an info message.""" + """Test that wrapping a sync provider logs a debug message.""" def sync_token_provider(): return "token" - with patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger: - _ensure_async_token_provider(sync_token_provider) - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0][0] + with patch("pyrit.auth.azure_auth.logger") as mock_logger: + ensure_async_token_provider(sync_token_provider) + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args[0][0] assert "synchronous token provider" in call_args.lower() assert "wrapping" in call_args.lower() @@ -124,7 +124,7 @@ def sync_token_provider(): with ( patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai, - patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger, + patch("pyrit.auth.azure_auth.logger") as mock_logger, ): mock_client = AsyncMock() mock_openai.return_value = mock_client @@ -135,14 +135,14 @@ def sync_token_provider(): api_key=sync_token_provider, ) - # Verify that info log was called about wrapping - mock_logger.info.assert_called() + # Verify that debug log was called about wrapping + mock_logger.debug.assert_called() info_call_found = False - for call in mock_logger.info.call_args_list: + for call in mock_logger.debug.call_args_list: if "synchronous token provider" in str(call).lower(): info_call_found = True break - assert info_call_found, "Expected info log about wrapping sync token provider" + assert info_call_found, "Expected debug log about wrapping sync token provider" # Verify AsyncOpenAI was initialized mock_openai.assert_called_once() @@ -223,7 +223,7 @@ def mock_sync_bearer_token_provider(): with ( patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai, - patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger, + patch("pyrit.auth.azure_auth.logger") as mock_logger, ): mock_client = AsyncMock() mock_openai.return_value = mock_client @@ -235,7 +235,7 @@ def mock_sync_bearer_token_provider(): ) # Verify that sync provider was wrapped - mock_logger.info.assert_called() + mock_logger.debug.assert_called() # Get the wrapped api_key call_kwargs = mock_openai.call_args[1]