Skip to content

Commit 739fa13

Browse files
committed
Bind transport sessions to the authenticated principal
Both HTTP transports now record the principal that created each session — the OAuth client together with the issuer and subject when the token verifier supplies them — and serve subsequent requests for that session only when they present the same principal. Requests presenting a different principal receive the same 404 response as for an unknown session ID, and SSE session entries are removed when the connection ends. Servers without authentication, and authentication backends other than the built-in BearerAuthBackend, are unaffected: no principal is recorded and the comparison always passes.
1 parent 1abcca2 commit 739fa13

5 files changed

Lines changed: 397 additions & 38 deletions

File tree

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import time
3-
from typing import Any
3+
from typing import Any, TypedDict
44

55
from pydantic import AnyHttpUrl
66
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
@@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
1919
self.scopes = auth_info.scopes
2020

2121

22+
class AuthorizationContext(TypedDict):
23+
client_id: str
24+
issuer: str | None
25+
subject: str | None
26+
27+
28+
def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
29+
"""Identify the principal `user` represents, for transports to compare
30+
against the principal that created a session. Components the token
31+
verifier does not supply are `None`, so the comparison degrades to the
32+
remaining components.
33+
34+
See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
35+
a verifier that populates `subject` and `claims` from an introspection
36+
response."""
37+
token = user.access_token
38+
issuer = (token.claims or {}).get("iss")
39+
return AuthorizationContext(
40+
client_id=token.client_id,
41+
issuer=str(issuer) if issuer is not None else None,
42+
subject=token.subject,
43+
)
44+
45+
2246
class BearerAuthBackend(AuthenticationBackend):
2347
"""
2448
Authentication backend that validates Bearer tokens using a TokenVerifier.

src/mcp/server/sse.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
5556
from mcp.server.transport_security import (
5657
TransportSecurityMiddleware,
5758
TransportSecuritySettings,
@@ -75,6 +76,9 @@ class SseServerTransport:
7576

7677
_endpoint: str
7778
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
79+
# Identity of the credential that created each session; requests for a
80+
# session must present the same credential.
81+
_session_owners: dict[UUID, AuthorizationContext]
7882
_security: TransportSecurityMiddleware
7983

8084
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
@@ -115,6 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
115119

116120
self._endpoint = endpoint
117121
self._read_stream_writers = {}
122+
self._session_owners = {}
118123
self._security = TransportSecurityMiddleware(security_settings)
119124
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
120125

@@ -142,6 +147,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
142147
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
143148

144149
session_id = uuid4()
150+
user = scope.get("user")
151+
if isinstance(user, AuthenticatedUser):
152+
self._session_owners[session_id] = authorization_context(user)
145153
self._read_stream_writers[session_id] = read_stream_writer
146154
logger.debug(f"Created new session with ID: {session_id}")
147155

@@ -177,26 +185,34 @@ async def sse_writer():
177185
}
178186
)
179187

180-
async with anyio.create_task_group() as tg:
181-
182-
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
183-
"""
184-
The EventSourceResponse returning signals a client close / disconnect.
185-
In this case we close our side of the streams to signal the client that
186-
the connection has been closed.
187-
"""
188-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
189-
scope, receive, send
190-
)
191-
await read_stream_writer.aclose()
192-
await write_stream_reader.aclose()
193-
logging.debug(f"Client session disconnected {session_id}")
194-
195-
logger.debug("Starting SSE response task")
196-
tg.start_soon(response_wrapper, scope, receive, send)
197-
198-
logger.debug("Yielding read and write streams")
199-
yield (read_stream, write_stream)
188+
try:
189+
async with anyio.create_task_group() as tg:
190+
191+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
192+
"""
193+
The EventSourceResponse returning signals a client close / disconnect.
194+
In this case we close our side of the streams to signal the client that
195+
the connection has been closed.
196+
"""
197+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
198+
scope, receive, send
199+
)
200+
await read_stream_writer.aclose()
201+
await write_stream_reader.aclose()
202+
await sse_stream_reader.aclose()
203+
logging.debug(f"Client session disconnected {session_id}")
204+
205+
logger.debug("Starting SSE response task")
206+
tg.start_soon(response_wrapper, scope, receive, send)
207+
208+
logger.debug("Yielding read and write streams")
209+
yield (read_stream, write_stream)
210+
finally:
211+
# The connection is gone: stop routing messages to this session
212+
# and drop its entries so they do not accumulate for the lifetime
213+
# of the transport.
214+
self._read_stream_writers.pop(session_id, None)
215+
self._session_owners.pop(session_id, None)
200216

201217
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
202218
logger.debug("Handling POST message")
@@ -227,6 +243,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
227243
response = Response("Could not find session", status_code=404)
228244
return await response(scope, receive, send)
229245

246+
user = scope.get("user")
247+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
248+
if requestor != self._session_owners.get(session_id):
249+
# A session can only be used with the credential that created it.
250+
# Respond exactly as if the session did not exist.
251+
logger.warning("Rejecting message for session %s: credential does not match", session_id)
252+
response = Response("Could not find session", status_code=404)
253+
return await response(scope, receive, send)
254+
230255
body = await request.body()
231256
logger.debug(f"Received JSON: {body}")
232257

src/mcp/server/streamable_http_manager.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import logging
77
from collections.abc import AsyncIterator
8-
from http import HTTPStatus
98
from typing import Any
109
from uuid import uuid4
1110

@@ -15,6 +14,7 @@
1514
from starlette.responses import Response
1615
from starlette.types import Receive, Scope, Send
1716

17+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
1818
from mcp.server.lowlevel.server import Server as MCPServer
1919
from mcp.server.streamable_http import (
2020
MCP_SESSION_ID_HEADER,
@@ -88,6 +88,9 @@ def __init__(
8888
# Session tracking (only used if not stateless)
8989
self._session_creation_lock = anyio.Lock()
9090
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
91+
# Identity of the credential that created each session; requests for a
92+
# session must present the same credential.
93+
self._session_owners: dict[str, AuthorizationContext] = {}
9194

9295
# The task group will be set during lifespan
9396
self._task_group = None
@@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
135138
self._task_group = None
136139
# Clear any remaining server instances
137140
self._server_instances.clear()
141+
self._session_owners.clear()
138142

139143
async def handle_request(
140144
self,
@@ -227,12 +231,32 @@ async def _handle_stateful_request(
227231
request = Request(scope, receive)
228232
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
229233

234+
user = scope.get("user")
235+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
236+
230237
# Existing session case
231-
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover
238+
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
232239
transport = self._server_instances[request_mcp_session_id]
240+
if requestor != self._session_owners.get(request_mcp_session_id):
241+
# A session can only be used with the credential that created
242+
# it. Respond exactly as if the session did not exist.
243+
logger.warning(
244+
"Rejecting request for session %s: credential does not match the one that created the session",
245+
request_mcp_session_id[:64],
246+
)
247+
body = JSONRPCError(
248+
jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found")
249+
)
250+
response = Response(
251+
body.model_dump_json(by_alias=True, exclude_none=True),
252+
status_code=404,
253+
media_type="application/json",
254+
)
255+
await response(scope, receive, send)
256+
return
233257
logger.debug("Session already exists, handling request directly")
234258
# Push back idle deadline on activity
235-
if transport.idle_scope is not None and self.session_idle_timeout is not None:
259+
if transport.idle_scope is not None and self.session_idle_timeout is not None: # pragma: no cover
236260
transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout
237261
await transport.handle_request(scope, receive, send)
238262
return
@@ -251,6 +275,8 @@ async def _handle_stateful_request(
251275
)
252276

253277
assert http_transport.mcp_session_id is not None
278+
if requestor is not None:
279+
self._session_owners[http_transport.mcp_session_id] = requestor
254280
self._server_instances[http_transport.mcp_session_id] = http_transport
255281
logger.info(f"Created new transport with session ID: {new_session_id}")
256282

@@ -281,6 +307,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
281307
assert http_transport.mcp_session_id is not None
282308
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
283309
self._server_instances.pop(http_transport.mcp_session_id, None)
310+
self._session_owners.pop(http_transport.mcp_session_id, None)
284311
await http_transport.terminate()
285312
except Exception:
286313
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
@@ -296,6 +323,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
296323
"active instances."
297324
)
298325
del self._server_instances[http_transport.mcp_session_id]
326+
self._session_owners.pop(http_transport.mcp_session_id, None)
299327

300328
# Assert task group is not None for type checking
301329
assert self._task_group is not None
@@ -306,19 +334,10 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
306334
await http_transport.handle_request(scope, receive, send)
307335
else:
308336
# Unknown or expired session ID - return 404 per MCP spec
309-
# TODO: Align error code once spec clarifies
310-
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
311-
error_response = JSONRPCError(
312-
jsonrpc="2.0",
313-
id="server-error",
314-
error=ErrorData(
315-
code=INVALID_REQUEST,
316-
message="Session not found",
317-
),
337+
body = JSONRPCError(
338+
jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found")
318339
)
319340
response = Response(
320-
content=error_response.model_dump_json(by_alias=True, exclude_none=True),
321-
status_code=HTTPStatus.NOT_FOUND,
322-
media_type="application/json",
341+
body.model_dump_json(by_alias=True, exclude_none=True), status_code=404, media_type="application/json"
323342
)
324343
await response(scope, receive, send)

0 commit comments

Comments
 (0)