Skip to content

Commit ce267b6

Browse files
authored
[v1.x] Bind transport sessions to the authenticated principal (#2719)
1 parent 1abcca2 commit ce267b6

5 files changed

Lines changed: 404 additions & 38 deletions

File tree

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import time
3-
from typing import Any
3+
from typing import Any, TypedDict
44

55
from pydantic import AnyHttpUrl
66
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
@@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
1919
self.scopes = auth_info.scopes
2020

2121

22+
class AuthorizationContext(TypedDict):
23+
client_id: str
24+
issuer: str | None
25+
subject: str | None
26+
27+
28+
def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
29+
"""Identify the principal `user` represents, for transports to compare
30+
against the principal that created a session. Components the token
31+
verifier does not supply are `None`, so the comparison degrades to the
32+
remaining components.
33+
34+
See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
35+
a verifier that populates `subject` and `claims` from an introspection
36+
response."""
37+
token = user.access_token
38+
issuer = (token.claims or {}).get("iss")
39+
return AuthorizationContext(
40+
client_id=token.client_id,
41+
issuer=str(issuer) if issuer is not None else None,
42+
subject=token.subject,
43+
)
44+
45+
2246
class BearerAuthBackend(AuthenticationBackend):
2347
"""
2448
Authentication backend that validates Bearer tokens using a TokenVerifier.

src/mcp/server/sse.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
5556
from mcp.server.transport_security import (
5657
TransportSecurityMiddleware,
5758
TransportSecuritySettings,
@@ -75,6 +76,9 @@ class SseServerTransport:
7576

7677
_endpoint: str
7778
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
79+
# Identity of the credential that created each session; requests for a
80+
# session must present the same credential.
81+
_session_owners: dict[UUID, AuthorizationContext]
7882
_security: TransportSecurityMiddleware
7983

8084
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
@@ -115,6 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
115119

116120
self._endpoint = endpoint
117121
self._read_stream_writers = {}
122+
self._session_owners = {}
118123
self._security = TransportSecurityMiddleware(security_settings)
119124
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
120125

@@ -142,6 +147,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
142147
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
143148

144149
session_id = uuid4()
150+
user = scope.get("user")
151+
if isinstance(user, AuthenticatedUser):
152+
self._session_owners[session_id] = authorization_context(user)
145153
self._read_stream_writers[session_id] = read_stream_writer
146154
logger.debug(f"Created new session with ID: {session_id}")
147155

@@ -177,26 +185,34 @@ async def sse_writer():
177185
}
178186
)
179187

180-
async with anyio.create_task_group() as tg:
181-
182-
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
183-
"""
184-
The EventSourceResponse returning signals a client close / disconnect.
185-
In this case we close our side of the streams to signal the client that
186-
the connection has been closed.
187-
"""
188-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
189-
scope, receive, send
190-
)
191-
await read_stream_writer.aclose()
192-
await write_stream_reader.aclose()
193-
logging.debug(f"Client session disconnected {session_id}")
194-
195-
logger.debug("Starting SSE response task")
196-
tg.start_soon(response_wrapper, scope, receive, send)
197-
198-
logger.debug("Yielding read and write streams")
199-
yield (read_stream, write_stream)
188+
try:
189+
async with anyio.create_task_group() as tg:
190+
191+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
192+
"""
193+
The EventSourceResponse returning signals a client close / disconnect.
194+
In this case we close our side of the streams to signal the client that
195+
the connection has been closed.
196+
"""
197+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
198+
scope, receive, send
199+
)
200+
await read_stream_writer.aclose()
201+
await write_stream_reader.aclose()
202+
await sse_stream_reader.aclose()
203+
logging.debug(f"Client session disconnected {session_id}")
204+
205+
logger.debug("Starting SSE response task")
206+
tg.start_soon(response_wrapper, scope, receive, send)
207+
208+
logger.debug("Yielding read and write streams")
209+
yield (read_stream, write_stream)
210+
finally:
211+
# The connection is gone: stop routing messages to this session
212+
# and drop its entries so they do not accumulate for the lifetime
213+
# of the transport.
214+
self._read_stream_writers.pop(session_id, None)
215+
self._session_owners.pop(session_id, None)
200216

201217
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
202218
logger.debug("Handling POST message")
@@ -227,6 +243,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
227243
response = Response("Could not find session", status_code=404)
228244
return await response(scope, receive, send)
229245

246+
user = scope.get("user")
247+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
248+
if requestor != self._session_owners.get(session_id):
249+
# A session can only be used with the credential that created it.
250+
# Respond exactly as if the session did not exist.
251+
logger.warning("Rejecting message for session %s: credential does not match", session_id)
252+
response = Response("Could not find session", status_code=404)
253+
return await response(scope, receive, send)
254+
230255
body = await request.body()
231256
logger.debug(f"Received JSON: {body}")
232257

src/mcp/server/streamable_http_manager.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import logging
77
from collections.abc import AsyncIterator
8-
from http import HTTPStatus
98
from typing import Any
109
from uuid import uuid4
1110

@@ -15,6 +14,7 @@
1514
from starlette.responses import Response
1615
from starlette.types import Receive, Scope, Send
1716

17+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
1818
from mcp.server.lowlevel.server import Server as MCPServer
1919
from mcp.server.streamable_http import (
2020
MCP_SESSION_ID_HEADER,
@@ -88,6 +88,9 @@ def __init__(
8888
# Session tracking (only used if not stateless)
8989
self._session_creation_lock = anyio.Lock()
9090
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
91+
# Identity of the credential that created each session; requests for a
92+
# session must present the same credential.
93+
self._session_owners: dict[str, AuthorizationContext] = {}
9194

9295
# The task group will be set during lifespan
9396
self._task_group = None
@@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
135138
self._task_group = None
136139
# Clear any remaining server instances
137140
self._server_instances.clear()
141+
self._session_owners.clear()
138142

139143
async def handle_request(
140144
self,
@@ -227,12 +231,32 @@ async def _handle_stateful_request(
227231
request = Request(scope, receive)
228232
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
229233

234+
user = scope.get("user")
235+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
236+
230237
# Existing session case
231-
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover
238+
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
232239
transport = self._server_instances[request_mcp_session_id]
240+
if requestor != self._session_owners.get(request_mcp_session_id):
241+
# A session can only be used with the credential that created
242+
# it. Respond exactly as if the session did not exist.
243+
logger.warning(
244+
"Rejecting request for session %s: credential does not match the one that created the session",
245+
request_mcp_session_id[:64],
246+
)
247+
body = JSONRPCError(
248+
jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found")
249+
)
250+
response = Response(
251+
body.model_dump_json(by_alias=True, exclude_none=True),
252+
status_code=404,
253+
media_type="application/json",
254+
)
255+
await response(scope, receive, send)
256+
return
233257
logger.debug("Session already exists, handling request directly")
234258
# Push back idle deadline on activity
235-
if transport.idle_scope is not None and self.session_idle_timeout is not None:
259+
if transport.idle_scope is not None and self.session_idle_timeout is not None: # pragma: no cover
236260
transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout
237261
await transport.handle_request(scope, receive, send)
238262
return
@@ -251,6 +275,8 @@ async def _handle_stateful_request(
251275
)
252276

253277
assert http_transport.mcp_session_id is not None
278+
if requestor is not None:
279+
self._session_owners[http_transport.mcp_session_id] = requestor
254280
self._server_instances[http_transport.mcp_session_id] = http_transport
255281
logger.info(f"Created new transport with session ID: {new_session_id}")
256282

@@ -281,6 +307,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
281307
assert http_transport.mcp_session_id is not None
282308
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
283309
self._server_instances.pop(http_transport.mcp_session_id, None)
310+
self._session_owners.pop(http_transport.mcp_session_id, None)
284311
await http_transport.terminate()
285312
except Exception:
286313
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
@@ -296,6 +323,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
296323
"active instances."
297324
)
298325
del self._server_instances[http_transport.mcp_session_id]
326+
self._session_owners.pop(http_transport.mcp_session_id, None)
299327

300328
# Assert task group is not None for type checking
301329
assert self._task_group is not None
@@ -306,19 +334,10 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
306334
await http_transport.handle_request(scope, receive, send)
307335
else:
308336
# Unknown or expired session ID - return 404 per MCP spec
309-
# TODO: Align error code once spec clarifies
310-
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
311-
error_response = JSONRPCError(
312-
jsonrpc="2.0",
313-
id="server-error",
314-
error=ErrorData(
315-
code=INVALID_REQUEST,
316-
message="Session not found",
317-
),
337+
body = JSONRPCError(
338+
jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found")
318339
)
319340
response = Response(
320-
content=error_response.model_dump_json(by_alias=True, exclude_none=True),
321-
status_code=HTTPStatus.NOT_FOUND,
322-
media_type="application/json",
341+
body.model_dump_json(by_alias=True, exclude_none=True), status_code=404, media_type="application/json"
323342
)
324343
await response(scope, receive, send)

0 commit comments

Comments
 (0)