diff --git a/python/packages/purview/agent_framework_purview/_processor.py b/python/packages/purview/agent_framework_purview/_processor.py index 0e197b78c0..fb115783f5 100644 --- a/python/packages/purview/agent_framework_purview/_processor.py +++ b/python/packages/purview/agent_framework_purview/_processor.py @@ -211,9 +211,8 @@ async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> Proce cache_key = create_protection_scopes_cache_key(ps_req) cached_ps_resp = await self._cache.get(cache_key) - if cached_ps_resp is not None: - if isinstance(cached_ps_resp, ProtectionScopesResponse): - ps_resp = cached_ps_resp + if cached_ps_resp is not None and isinstance(cached_ps_resp, ProtectionScopesResponse): + ps_resp = cached_ps_resp else: try: ps_resp = await self._client.get_protection_scopes(ps_req) diff --git a/python/packages/purview/tests/test_cache.py b/python/packages/purview/tests/test_cache.py index 2089d9e2e7..1892842d42 100644 --- a/python/packages/purview/tests/test_cache.py +++ b/python/packages/purview/tests/test_cache.py @@ -119,6 +119,25 @@ class CustomObject: assert result == obj + async def test_estimate_size_conservative_fallback_when_all_size_methods_fail(self, monkeypatch) -> None: + """Test that the cache returns a conservative size estimate when all strategies fail.""" + cache = InMemoryCacheProvider() + + class BadString: + def __str__(self) -> str: + raise RuntimeError("boom") + + def raise_getsizeof(_: object) -> int: + raise RuntimeError("no sizeof") + + monkeypatch.setattr("agent_framework_purview._cache.sys.getsizeof", raise_getsizeof) + + # Arrange/Act + size = cache._estimate_size(BadString()) + + # Assert + assert size == 1024 + async def test_cache_multiple_updates(self) -> None: """Test that updating a key multiple times maintains correct size tracking.""" cache = InMemoryCacheProvider(max_size_bytes=1000) diff --git a/python/packages/purview/tests/test_chat_middleware.py b/python/packages/purview/tests/test_chat_middleware.py index 8d414babb9..3f9595e721 100644 --- a/python/packages/purview/tests/test_chat_middleware.py +++ b/python/packages/purview/tests/test_chat_middleware.py @@ -204,6 +204,39 @@ async def mock_next(ctx: ChatContext) -> None: with pytest.raises(PurviewPaymentRequiredError): await middleware.process(context, mock_next) + async def test_chat_middleware_handles_payment_required_post_check(self, mock_credential: AsyncMock) -> None: + """Test that 402 in post-check is raised when ignore_payment_required=False.""" + from agent_framework_purview._exceptions import PurviewPaymentRequiredError + + settings = PurviewSettings(app_name="Test App", ignore_payment_required=False) + middleware = PurviewChatPolicyMiddleware(mock_credential, settings) + + chat_client = DummyChatClient() + chat_options = MagicMock() + chat_options.model = "test-model" + context = ChatContext( + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options + ) + + call_count = 0 + + async def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (False, "user-123") + raise PurviewPaymentRequiredError("Payment required") + + with patch.object(middleware._processor, "process_messages", side_effect=side_effect): + + async def mock_next(ctx: ChatContext) -> None: + result = MagicMock() + result.messages = [ChatMessage(role=Role.ASSISTANT, text="OK")] + ctx.result = result + + with pytest.raises(PurviewPaymentRequiredError): + await middleware.process(context, mock_next) + async def test_chat_middleware_ignores_payment_required_when_configured(self, mock_credential: AsyncMock) -> None: """Test that 402 is ignored when ignore_payment_required=True.""" from agent_framework_purview._exceptions import PurviewPaymentRequiredError @@ -274,3 +307,58 @@ async def mock_next(ctx: ChatContext) -> None: await middleware.process(context, mock_next) # Next should have been called assert context.result is not None + + async def test_chat_middleware_raises_on_pre_check_exception_when_ignore_exceptions_false( + self, mock_credential: AsyncMock + ) -> None: + """Test that exceptions are propagated by default when ignore_exceptions=False.""" + settings = PurviewSettings(app_name="Test App", ignore_exceptions=False) + middleware = PurviewChatPolicyMiddleware(mock_credential, settings) + + chat_client = DummyChatClient() + chat_options = MagicMock() + chat_options.model = "test-model" + context = ChatContext( + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options + ) + + with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")): + + async def mock_next(_: ChatContext) -> None: + raise AssertionError("next should not be called") + + with pytest.raises(ValueError, match="boom"): + await middleware.process(context, mock_next) + + async def test_chat_middleware_raises_on_post_check_exception_when_ignore_exceptions_false( + self, mock_credential: AsyncMock + ) -> None: + """Test that post-check exceptions are propagated by default.""" + settings = PurviewSettings(app_name="Test App", ignore_exceptions=False) + middleware = PurviewChatPolicyMiddleware(mock_credential, settings) + + chat_client = DummyChatClient() + chat_options = MagicMock() + chat_options.model = "test-model" + context = ChatContext( + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options + ) + + call_count = 0 + + async def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (False, "user-123") + raise ValueError("post") + + with patch.object(middleware._processor, "process_messages", side_effect=side_effect): + + async def mock_next(ctx: ChatContext) -> None: + result = MagicMock() + result.messages = [ChatMessage(role=Role.ASSISTANT, text="OK")] + ctx.result = result + + with pytest.raises(ValueError, match="post"): + await middleware.process(context, mock_next) diff --git a/python/packages/purview/tests/test_client.py b/python/packages/purview/tests/test_client.py index 7953c16b77..b740b3f09c 100644 --- a/python/packages/purview/tests/test_client.py +++ b/python/packages/purview/tests/test_client.py @@ -2,6 +2,7 @@ """Tests for Purview client.""" +from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -18,6 +19,8 @@ PurviewServiceError, ) from agent_framework_purview._models import ( + ContentActivitiesRequest, + ContentActivitiesResponse, PolicyLocation, ProcessContentRequest, ProtectionScopesRequest, @@ -47,7 +50,9 @@ def settings(self) -> PurviewSettings: return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user") @pytest.fixture - async def client(self, mock_credential: MagicMock, settings: PurviewSettings) -> PurviewClient: + async def client( + self, mock_credential: MagicMock, settings: PurviewSettings + ) -> AsyncGenerator[PurviewClient, None]: """Create a PurviewClient with mock credential.""" client = PurviewClient(mock_credential, settings, timeout=10.0) yield client @@ -185,6 +190,215 @@ async def test_get_protection_scopes_success(self, client: PurviewClient) -> Non assert response.scope_identifier == "scope-123" assert response.scopes == [] + async def test_get_protection_scopes_uses_etag_header_when_present(self, client: PurviewClient) -> None: + """Test that get_protection_scopes prefers the HTTP ETag header when present.""" + from agent_framework_purview._models import ProtectionScopesResponse + + location = PolicyLocation(**{"@odata.type": "microsoft.graph.policyLocationApplication", "value": "app-id"}) + request = ProtectionScopesRequest( + user_id="user-123", tenant_id="tenant-456", locations=[location], correlation_id="corr-789" + ) + + response_obj = ProtectionScopesResponse(**{"scopeIdentifier": "scope-from-body", "value": []}) + + with patch.object( + client, + "_post", + return_value=(response_obj, {"etag": '"etag-from-header"'}), + ): + response = await client.get_protection_scopes(request) + + assert response.scope_identifier == "etag-from-header" + + async def test_post_402_returns_empty_response_when_ignore_payment_required_enabled( + self, mock_credential: MagicMock + ) -> None: + """Test that 402 is suppressed when ignore_payment_required=True.""" + from agent_framework_purview._models import ProcessContentResponse + + settings = PurviewSettings(app_name="Test App", ignore_payment_required=True) + client = PurviewClient(mock_credential, settings) + + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + + resp = httpx.Response(402, text="Payment required", request=httpx.Request("POST", "http://test")) + + with patch.object(client._client, "post", return_value=resp): + result = await client._post("http://test", request, ProcessContentResponse, token="fake-token") + + assert isinstance(result, ProcessContentResponse) + await client.close() + + async def test_post_sets_request_and_response_correlation_id(self, client: PurviewClient) -> None: + """Test that correlation_id is injected into request headers and hydrated from response headers.""" + from agent_framework_purview._models import ProcessContentResponse + + # correlation_id is optional and should be auto-generated when empty + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request.correlation_id = "" # force auto-generation branch + + captured_headers: dict[str, str] = {} + + async def fake_post(url: str, json=None, headers=None): + nonlocal captured_headers + captured_headers = dict(headers or {}) + return httpx.Response( + 200, + json={"id": "resp-1", "protectionScopeState": "notModified"}, + headers={"client-request-id": "corr-from-response"}, + request=httpx.Request("POST", url), + ) + + with patch.object(client._client, "post", side_effect=fake_post): + result_obj, result_headers = await client._post( + "http://test", + request, + ProcessContentResponse, + token="fake-token", + return_response=True, + ) + + assert "client-request-id" in captured_headers + assert captured_headers["client-request-id"] + assert result_headers["client-request-id"] == "corr-from-response" + assert result_obj.correlation_id == "corr-from-response" + + async def test_process_content_402_returns_empty_when_ignored(self, mock_credential: MagicMock) -> None: + """Test that process_content returns an empty response (non-tuple path) when 402 is ignored.""" + from agent_framework_purview._models import ProcessContentResponse + + settings = PurviewSettings(app_name="Test App", ignore_payment_required=True) + client = PurviewClient(mock_credential, settings) + + req = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 402 + mock_response.text = "Payment required" + + with patch.object(client._client, "post", return_value=mock_response): + response = await client.process_content(req) + + assert isinstance(response, ProcessContentResponse) + await client.close() + + async def test_post_sets_correlation_id_attribute_on_recording_span(self, client: PurviewClient) -> None: + """Test that correlation_id is added to the active span when recording is enabled.""" + from agent_framework_purview._models import ProcessContentResponse + + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request.correlation_id = "corr-123" + + class RecordingSpan: + def __init__(self) -> None: + self.attributes: dict[str, str] = {} + + def is_recording(self) -> bool: + return True + + def set_attribute(self, key: str, value: str) -> None: + self.attributes[key] = value + + span = RecordingSpan() + + with ( + patch("agent_framework_purview._client.trace.get_current_span", return_value=span), + patch.object( + client._client, + "post", + return_value=httpx.Response( + 200, + json={"id": "resp-1", "protectionScopeState": "notModified"}, + headers={}, + request=httpx.Request("POST", "http://test"), + ), + ), + ): + await client._post("http://test", request, ProcessContentResponse, token="fake-token") + + assert span.attributes["correlation_id"] == "corr-123" + + async def test_post_uses_constructor_when_response_type_has_no_model_validate(self, client: PurviewClient) -> None: + """Test that _post falls back to the response type constructor when model_validate is absent.""" + + class DummyResponse: + def __init__(self, **data): + self.data = data + + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request.correlation_id = "corr-123" + + with patch.object( + client._client, + "post", + return_value=httpx.Response( + 200, + json={"hello": "world"}, + headers={}, + request=httpx.Request("POST", "http://test"), + ), + ): + result = await client._post("http://test", request, DummyResponse, token="fake-token") + + assert isinstance(result, DummyResponse) + assert result.data == {"hello": "world"} + + async def test_send_content_activities_success(self, client: PurviewClient, content_to_process_factory) -> None: + """Test send_content_activities success path.""" + request = ContentActivitiesRequest( + user_id="user-123", + tenant_id="tenant-456", + content_to_process=content_to_process_factory("hello"), + correlation_id="corr-1", + ) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.json.return_value = {"error": None} + + with patch.object(client._client, "post", return_value=mock_response): + resp = await client.send_content_activities(request) + + assert isinstance(resp, ContentActivitiesResponse) + + async def test_post_handles_invalid_json_response_body(self, client: PurviewClient) -> None: + """Test that invalid JSON bodies fall back to an empty dict.""" + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request.correlation_id = "corr-123" + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.json.side_effect = ValueError("not json") + + with patch.object(client._client, "post", return_value=mock_response): + result = await client._post("http://test", request, ContentActivitiesResponse, token="fake-token") + + assert isinstance(result, ContentActivitiesResponse) + + async def test_post_deserialization_failure_raises_purview_service_error(self, client: PurviewClient) -> None: + """Test that response deserialization errors are wrapped as PurviewServiceError.""" + + class BadResponseType: + @classmethod + def model_validate(cls, value): + raise RuntimeError("boom") + + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request.correlation_id = "corr-123" + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.json.return_value = {"any": "data"} + + with ( + patch.object(client._client, "post", return_value=mock_response), + pytest.raises(PurviewServiceError, match="Failed to deserialize Purview response"), + ): + await client._post("http://test", request, BadResponseType, token="fake-token") + async def test_client_close(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None: """Test client properly closes HTTP client.""" client = PurviewClient(mock_credential, settings) diff --git a/python/packages/purview/tests/test_middleware.py b/python/packages/purview/tests/test_middleware.py index 9426bc66af..b973e8ea34 100644 --- a/python/packages/purview/tests/test_middleware.py +++ b/python/packages/purview/tests/test_middleware.py @@ -153,6 +153,92 @@ async def mock_next(ctx: AgentRunContext) -> None: for call in mock_process.call_args_list: assert call[0][1] == Activity.UPLOAD_TEXT + async def test_middleware_streaming_skips_post_check( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test that streaming results skip post-check evaluation.""" + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")]) + context.is_streaming = True + + with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="streaming")]) + + await middleware.process(context, mock_next) + + assert mock_proc.call_count == 1 + + async def test_middleware_payment_required_in_pre_check_raises_by_default( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test that 402 in pre-check is raised when ignore_payment_required=False.""" + from agent_framework_purview._exceptions import PurviewPaymentRequiredError + + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")]) + + with patch.object( + middleware._processor, + "process_messages", + side_effect=PurviewPaymentRequiredError("Payment required"), + ): + + async def mock_next(_: AgentRunContext) -> None: + raise AssertionError("next should not be called") + + with pytest.raises(PurviewPaymentRequiredError): + await middleware.process(context, mock_next) + + async def test_middleware_payment_required_in_post_check_raises_by_default( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test that 402 in post-check is raised when ignore_payment_required=False.""" + from agent_framework_purview._exceptions import PurviewPaymentRequiredError + + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")]) + + call_count = 0 + + async def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (False, "user-123") + raise PurviewPaymentRequiredError("Payment required") + + with patch.object(middleware._processor, "process_messages", side_effect=side_effect): + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="OK")]) + + with pytest.raises(PurviewPaymentRequiredError): + await middleware.process(context, mock_next) + + async def test_middleware_post_check_exception_raises_when_ignore_exceptions_false( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test that post-check exceptions are propagated when ignore_exceptions=False.""" + middleware._settings.ignore_exceptions = False + + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")]) + + call_count = 0 + + async def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (False, "user-123") + raise ValueError("Post-check blew up") + + with patch.object(middleware._processor, "process_messages", side_effect=side_effect): + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="OK")]) + + with pytest.raises(ValueError, match="Post-check blew up"): + await middleware.process(context, mock_next) + async def test_middleware_handles_pre_check_exception( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: diff --git a/python/packages/purview/tests/test_processor.py b/python/packages/purview/tests/test_processor.py index c517da2459..11f48ed199 100644 --- a/python/packages/purview/tests/test_processor.py +++ b/python/packages/purview/tests/test_processor.py @@ -242,6 +242,83 @@ async def test_process_with_scopes_calls_client_methods( # The response should have id=204 (No Content) when no scopes apply assert response.id == "204" + async def test_process_with_scopes_ignores_unexpected_cached_value_type( + self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test that a corrupted cache entry does not crash processing.""" + from agent_framework_purview._models import ( + ExecutionMode, + PolicyLocation, + PolicyScope, + ProcessContentResponse, + ProtectionScopeActivities, + ProtectionScopesResponse, + ) + + request = process_content_request_factory() + + # Return a valid, inline scope so we stay on the normal (non-background) path. + scope_location = PolicyLocation(**{ + "@odata.type": "microsoft.graph.policyLocationApplication", + "value": "app-id", + }) + scope = PolicyScope(**{ + "activities": ProtectionScopeActivities.UPLOAD_TEXT, + "locations": [scope_location], + "execution_mode": ExecutionMode.EVALUATE_INLINE, + }) + mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": [scope]})) + mock_client.process_content = AsyncMock( + return_value=ProcessContentResponse(**{"id": "ok", "protectionScopeState": "notModified"}) + ) + + # First cache read is the tenant payment key (None). Second is the scopes cache (corrupt value). + processor._cache.get = AsyncMock(side_effect=[None, "corrupt-value"]) # type: ignore[method-assign] + processor._cache.set = AsyncMock() # type: ignore[method-assign] + + response = await processor._process_with_scopes(request) + + assert response.id == "ok" + mock_client.get_protection_scopes.assert_called_once() + mock_client.process_content.assert_called_once() + + async def test_process_with_scopes_uses_tenant_payment_exception_cache( + self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test that a cached 402 exception short-circuits all subsequent requests for the tenant.""" + from agent_framework_purview._exceptions import PurviewPaymentRequiredError + + request = process_content_request_factory() + + processor._cache.get = AsyncMock(return_value=PurviewPaymentRequiredError("Payment required")) # type: ignore[method-assign] + + with pytest.raises(PurviewPaymentRequiredError): + await processor._process_with_scopes(request) + + mock_client.get_protection_scopes.assert_not_called() + + async def test_process_content_background_retries_on_modified_state( + self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test offline background processing invalidates cache and retries when scope state changes.""" + from agent_framework_purview._models import ProcessContentResponse + + request = process_content_request_factory() + request.scope_identifier = "etag-1" + + mock_client.process_content = AsyncMock( + side_effect=[ + ProcessContentResponse(**{"id": "r1", "protectionScopeState": "modified"}), + ProcessContentResponse(**{"id": "r2", "protectionScopeState": "notModified"}), + ] + ) + processor._cache.remove = AsyncMock() # type: ignore[method-assign] + + await processor._process_content_background(request, cache_key="purview:protection_scopes:abc") + + processor._cache.remove.assert_called_once_with("purview:protection_scopes:abc") + assert mock_client.process_content.call_count == 2 + async def test_map_messages_with_user_id_in_additional_properties(self, mock_client: AsyncMock) -> None: """Test user_id extraction from message additional_properties.""" settings = PurviewSettings(