diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 4b5caa9cc..cc979f815 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -4,7 +4,7 @@ from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession from .server.stdio import stdio_server -from .shared.exceptions import MCPError, UrlElicitationRequiredError +from .shared.exceptions import HttpError, MCPError, UrlElicitationRequiredError from .types import ( CallToolRequest, ClientCapabilities, @@ -81,6 +81,7 @@ "ErrorData", "GetPromptRequest", "GetPromptResult", + "HttpError", "Implementation", "IncludeContext", "InitializeRequest", diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..03d459242 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -13,6 +13,7 @@ from mcp import types from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared.exceptions import HttpError from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -142,10 +143,28 @@ async def post_writer(endpoint_url: str): exclude_unset=True, ), ) + if response.status_code in (401, 403): + status_label = "Unauthorized" if response.status_code == 401 else "Forbidden" + exc = HttpError( + response.status_code, + f"HTTP {response.status_code} {status_label}", + ) + await read_stream_writer.send(exc) + raise exc response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: # pragma: lax no cover + except HttpError: # pragma: lax no cover + raise + except httpx.HTTPStatusError as exc: # pragma: lax no cover logger.exception("Error in post_writer") + http_exc = HttpError( + exc.response.status_code, + f"HTTP {exc.response.status_code}", + ) + await read_stream_writer.send(http_exc) + except Exception as exc: # pragma: lax no cover + logger.exception("Error in post_writer") + await read_stream_writer.send(exc) finally: await write_stream.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..4d3e10c56 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -17,6 +17,7 @@ from mcp.client._transport import TransportStreams from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.exceptions import HttpError from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( INTERNAL_ERROR, @@ -269,17 +270,41 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: if response.status_code == 404: # pragma: no branch if isinstance(message, JSONRPCRequest): # pragma: no branch - error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") + error_data = ErrorData( + code=INVALID_REQUEST, + message="Session terminated (HTTP 404)", + data={"http_status": 404}, + ) session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) + else: + raise HttpError(404, "Session terminated (HTTP 404)") return + if response.status_code in (401, 403): + status_label = "Unauthorized" if response.status_code == 401 else "Forbidden" + error_message = f"HTTP {response.status_code} {status_label}" + if isinstance(message, JSONRPCRequest): + error_data = ErrorData( + code=INTERNAL_ERROR, + message=error_message, + data={"http_status": response.status_code}, + ) + session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) + await ctx.read_stream_writer.send(session_message) + raise HttpError(response.status_code, error_message) + if response.status_code >= 400: + error_message = f"HTTP {response.status_code}" if isinstance(message, JSONRPCRequest): - error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") + error_data = ErrorData( + code=INTERNAL_ERROR, + message=error_message, + data={"http_status": response.status_code}, + ) session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) - return + raise HttpError(response.status_code, error_message) if is_initialization: self._maybe_extract_session_id_from_response(response) @@ -467,10 +492,14 @@ async def post_writer( ) async def handle_request_async(): - if is_resumption: - await self._handle_resumption_request(ctx) - else: - await self._handle_post_request(ctx) + try: + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + except Exception as exc: + logger.exception("Error handling request") + await read_stream_writer.send(exc) # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): @@ -478,8 +507,9 @@ async def handle_request_async(): else: await handle_request_async() - except Exception: # pragma: lax no cover + except Exception as exc: # pragma: lax no cover logger.exception("Error in post_writer") + await read_stream_writer.send(exc) finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..7cd84654d 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -41,6 +41,26 @@ def __str__(self) -> str: return self.message +class HttpError(Exception): + """Raised when an MCP HTTP transport receives a non-2xx response. + + Preserves the original HTTP status code so callers can distinguish + auth errors (401/403) from other failures (404, 5xx, etc.). + """ + + def __init__(self, status_code: int, message: str | None = None, body: str | None = None): + self.status_code = status_code + self.body = body + if message is None: + message = f"HTTP {status_code}" + super().__init__(message) + + @property + def is_auth_error(self) -> bool: + """True for 401 Unauthorized or 403 Forbidden responses.""" + return self.status_code in (401, 403) + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode.