|
35 | 35 | handle_token_response_scopes, |
36 | 36 | is_valid_client_metadata_url, |
37 | 37 | should_use_client_metadata_url, |
| 38 | + validate_authorization_response_iss, |
| 39 | + validate_metadata_issuer, |
38 | 40 | ) |
39 | 41 | from mcp.client.streamable_http import MCP_PROTOCOL_VERSION |
40 | 42 | from mcp.shared.auth import ( |
| 43 | + AuthorizationCodeResult, |
41 | 44 | OAuthClientInformationFull, |
42 | 45 | OAuthClientMetadata, |
43 | 46 | OAuthMetadata, |
@@ -97,7 +100,7 @@ class OAuthContext: |
97 | 100 | client_metadata: OAuthClientMetadata |
98 | 101 | storage: TokenStorage |
99 | 102 | redirect_handler: Callable[[str], Awaitable[None]] | None |
100 | | - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None |
| 103 | + callback_handler: Callable[[], Awaitable[AuthorizationCodeResult]] | None |
101 | 104 | timeout: float = 300.0 |
102 | 105 | client_metadata_url: str | None = None |
103 | 106 |
|
@@ -227,7 +230,7 @@ def __init__( |
227 | 230 | client_metadata: OAuthClientMetadata, |
228 | 231 | storage: TokenStorage, |
229 | 232 | redirect_handler: Callable[[str], Awaitable[None]] | None = None, |
230 | | - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, |
| 233 | + callback_handler: Callable[[], Awaitable[AuthorizationCodeResult]] | None = None, |
231 | 234 | timeout: float = 300.0, |
232 | 235 | client_metadata_url: str | None = None, |
233 | 236 | validate_resource_url: Callable[[str, str | None], Awaitable[None]] | None = None, |
@@ -356,16 +359,19 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: |
356 | 359 | await self.context.redirect_handler(authorization_url) |
357 | 360 |
|
358 | 361 | # Wait for callback |
359 | | - auth_code, returned_state = await self.context.callback_handler() |
| 362 | + result = await self.context.callback_handler() |
360 | 363 |
|
361 | | - if returned_state is None or not secrets.compare_digest(returned_state, state): |
362 | | - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") |
| 364 | + if result.state is None or not secrets.compare_digest(result.state, state): |
| 365 | + raise OAuthFlowError(f"State parameter mismatch: {result.state} != {state}") |
363 | 366 |
|
364 | | - if not auth_code: |
| 367 | + # RFC 9207: validate the authorization-response issuer |
| 368 | + validate_authorization_response_iss(result.iss, self.context.oauth_metadata) |
| 369 | + |
| 370 | + if not result.code: |
365 | 371 | raise OAuthFlowError("No authorization code received") |
366 | 372 |
|
367 | 373 | # Return auth code and code verifier for token exchange |
368 | | - return auth_code, pkce_params.code_verifier |
| 374 | + return result.code, pkce_params.code_verifier |
369 | 375 |
|
370 | 376 | def _get_token_endpoint(self) -> str: |
371 | 377 | if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: |
@@ -570,6 +576,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. |
570 | 576 | if not ok: |
571 | 577 | break |
572 | 578 | if ok and asm: |
| 579 | + # SEP-2468: metadata issuer must match the discovery issuer |
| 580 | + if self.context.auth_server_url is not None: |
| 581 | + validate_metadata_issuer(asm, self.context.auth_server_url) |
573 | 582 | self.context.oauth_metadata = asm |
574 | 583 | break |
575 | 584 | else: |
|
0 commit comments