Skip to content

Commit 56dbdba

Browse files
committed
Bind transport sessions to the authenticated principal
Both HTTP transports now record the principal that created each session — the OAuth client together with the issuer and subject when the token verifier supplies them — and serve subsequent requests for that session only when they present the same principal. Requests presenting a different principal receive the same 404 response as for an unknown session ID, and SSE session entries are removed when the connection ends. Servers without authentication, and authentication backends other than the built-in BearerAuthBackend, are unaffected: no principal is recorded and the comparison always passes. The new in-process SSE tests bring connect_sse, handle_post_message, and TransportSecurityMiddleware under tracked coverage, so the corresponding no-cover pragmas are removed.
1 parent 2472563 commit 56dbdba

7 files changed

Lines changed: 637 additions & 42 deletions

File tree

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import time
3-
from typing import Any
3+
from typing import Any, TypedDict
44

55
from pydantic import AnyHttpUrl
66
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
@@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
1919
self.scopes = auth_info.scopes
2020

2121

22+
class AuthorizationContext(TypedDict):
23+
client_id: str
24+
issuer: str | None
25+
subject: str | None
26+
27+
28+
def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
29+
"""Identify the principal `user` represents, for transports to compare
30+
against the principal that created a session. Components the token
31+
verifier does not supply are `None`, so the comparison degrades to the
32+
remaining components.
33+
34+
See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
35+
a verifier that populates `subject` and `claims` from an introspection
36+
response."""
37+
token = user.access_token
38+
issuer = (token.claims or {}).get("iss")
39+
return AuthorizationContext(
40+
client_id=token.client_id,
41+
issuer=str(issuer) if issuer is not None else None,
42+
subject=token.subject,
43+
)
44+
45+
2246
class BearerAuthBackend(AuthenticationBackend):
2347
"""Authentication backend that validates Bearer tokens using a TokenVerifier."""
2448

src/mcp/server/sse.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ async def handle_sse(request):
5050
from starlette.types import Receive, Scope, Send
5151

5252
from mcp import types
53+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
5354
from mcp.server.transport_security import (
5455
TransportSecurityMiddleware,
5556
TransportSecuritySettings,
@@ -73,6 +74,9 @@ class SseServerTransport:
7374

7475
_endpoint: str
7576
_read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]]
77+
# Identity of the credential that created each session; requests for a
78+
# session must present the same credential.
79+
_session_owners: dict[UUID, AuthorizationContext]
7680
_security: TransportSecurityMiddleware
7781

7882
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
@@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
112116

113117
self._endpoint = endpoint
114118
self._read_stream_writers = {}
119+
self._session_owners = {}
115120
self._security = TransportSecurityMiddleware(security_settings)
116121
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
117122

118123
@asynccontextmanager
119124
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
120-
if scope["type"] != "http": # pragma: no cover
125+
if scope["type"] != "http":
121126
logger.error("connect_sse received non-HTTP request")
122127
raise ValueError("connect_sse can only handle HTTP requests")
123128

124129
# Validate request headers for DNS rebinding protection
125130
request = Request(scope, receive)
126131
error_response = await self._security.validate_request(request, is_post=False)
127-
if error_response: # pragma: no cover
132+
if error_response:
128133
await error_response(scope, receive, send)
129134
raise ValueError("Request validation failed")
130135

@@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
134139
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
135140

136141
session_id = uuid4()
142+
user = scope.get("user")
143+
if isinstance(user, AuthenticatedUser):
144+
self._session_owners[session_id] = authorization_context(user)
137145
self._read_stream_writers[session_id] = read_stream_writer
138146
logger.debug(f"Created new session with ID: {session_id}")
139147

@@ -169,35 +177,38 @@ async def sse_writer():
169177
}
170178
)
171179

172-
async with anyio.create_task_group() as tg:
173-
174-
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
175-
"""The EventSourceResponse returning signals a client close / disconnect.
176-
In this case we close our side of the streams to signal the client that
177-
the connection has been closed.
178-
"""
179-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
180-
scope, receive, send
181-
)
182-
await sse_stream_reader.aclose()
183-
await read_stream_writer.aclose()
184-
await write_stream_reader.aclose()
185-
self._read_stream_writers.pop(session_id, None)
186-
logging.debug(f"Client session disconnected {session_id}")
180+
try:
181+
async with anyio.create_task_group() as tg:
182+
183+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
184+
"""The EventSourceResponse returning signals a client close / disconnect.
185+
In this case we close our side of the streams to signal the client that
186+
the connection has been closed.
187+
"""
188+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
189+
scope, receive, send
190+
)
191+
await read_stream_writer.aclose()
192+
await write_stream_reader.aclose()
193+
await sse_stream_reader.aclose()
194+
logging.debug(f"Client session disconnected {session_id}")
187195

188-
logger.debug("Starting SSE response task")
189-
tg.start_soon(response_wrapper, scope, receive, send)
196+
logger.debug("Starting SSE response task")
197+
tg.start_soon(response_wrapper, scope, receive, send)
190198

191-
logger.debug("Yielding read and write streams")
192-
yield (read_stream, write_stream)
199+
logger.debug("Yielding read and write streams")
200+
yield (read_stream, write_stream)
201+
finally:
202+
self._read_stream_writers.pop(session_id, None)
203+
self._session_owners.pop(session_id, None)
193204

194205
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
195206
logger.debug("Handling POST message")
196207
request = Request(scope, receive)
197208

198209
# Validate request headers for DNS rebinding protection
199210
error_response = await self._security.validate_request(request, is_post=True)
200-
if error_response: # pragma: no cover
211+
if error_response:
201212
return await error_response(scope, receive, send)
202213

203214
session_id_param = request.query_params.get("session_id")
@@ -220,13 +231,22 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
220231
response = Response("Could not find session", status_code=404)
221232
return await response(scope, receive, send)
222233

234+
user = scope.get("user")
235+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
236+
if requestor != self._session_owners.get(session_id):
237+
# A session can only be used with the credential that created it.
238+
# Respond exactly as if the session did not exist.
239+
logger.warning("Rejecting message for session %s: credential does not match", session_id)
240+
response = Response("Could not find session", status_code=404)
241+
return await response(scope, receive, send)
242+
223243
body = await request.body()
224244
logger.debug(f"Received JSON: {body}")
225245

226246
try:
227247
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
228248
logger.debug(f"Validated client message: {message}")
229-
except ValidationError as err: # pragma: no cover
249+
except ValidationError as err:
230250
logger.exception("Failed to parse message")
231251
response = Response("Could not parse message", status_code=400)
232252
await response(scope, receive, send)

src/mcp/server/streamable_http_manager.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import logging
77
from collections.abc import AsyncIterator
8-
from http import HTTPStatus
98
from typing import TYPE_CHECKING, Any
109
from uuid import uuid4
1110

@@ -15,6 +14,7 @@
1514
from starlette.responses import Response
1615
from starlette.types import Receive, Scope, Send
1716

17+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
1818
from mcp.server.streamable_http import (
1919
MCP_SESSION_ID_HEADER,
2020
EventStore,
@@ -89,6 +89,9 @@ def __init__(
8989
# Session tracking (only used if not stateless)
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
92+
# Identity of the credential that created each session; requests for a
93+
# session must present the same credential.
94+
self._session_owners: dict[str, AuthorizationContext] = {}
9295

9396
# The task group will be set during lifespan
9497
self._task_group = None
@@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
135138
self._task_group = None
136139
# Clear any remaining server instances
137140
self._server_instances.clear()
141+
self._session_owners.clear()
138142

139143
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140144
"""Process ASGI request with proper session handling and transport setup.
@@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
192196
request = Request(scope, receive)
193197
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
194198

199+
user = scope.get("user")
200+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
201+
195202
# Existing session case
196203
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
197204
transport = self._server_instances[request_mcp_session_id]
205+
if requestor != self._session_owners.get(request_mcp_session_id):
206+
# A session can only be used with the credential that created
207+
# it. Respond exactly as if the session did not exist.
208+
logger.warning(
209+
"Rejecting request for session %s: credential does not match the one that created the session",
210+
request_mcp_session_id[:64],
211+
)
212+
body = JSONRPCError(
213+
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
214+
)
215+
response = Response(
216+
body.model_dump_json(by_alias=True, exclude_unset=True),
217+
status_code=404,
218+
media_type="application/json",
219+
)
220+
await response(scope, receive, send)
221+
return
198222
logger.debug("Session already exists, handling request directly")
199223
# Push back idle deadline on activity
200224
if transport.idle_scope is not None and self.session_idle_timeout is not None:
@@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
216240
)
217241

218242
assert http_transport.mcp_session_id is not None
243+
if requestor is not None:
244+
self._session_owners[http_transport.mcp_session_id] = requestor
219245
self._server_instances[http_transport.mcp_session_id] = http_transport
220246
logger.info(f"Created new transport with session ID: {new_session_id}")
221247

@@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
246272
assert http_transport.mcp_session_id is not None
247273
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
248274
self._server_instances.pop(http_transport.mcp_session_id, None)
275+
self._session_owners.pop(http_transport.mcp_session_id, None)
249276
await http_transport.terminate()
250277
except Exception:
251278
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
@@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
260287
f"{http_transport.mcp_session_id} from active instances."
261288
)
262289
del self._server_instances[http_transport.mcp_session_id]
290+
self._session_owners.pop(http_transport.mcp_session_id, None)
263291

264292
# Assert task group is not None for type checking
265293
assert self._task_group is not None
@@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
273301
# TODO: Align error code once spec clarifies
274302
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
275303
logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}")
276-
error_response = JSONRPCError(
277-
jsonrpc="2.0",
278-
id=None,
279-
error=ErrorData(code=INVALID_REQUEST, message="Session not found"),
304+
body = JSONRPCError(
305+
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
280306
)
281307
response = Response(
282-
content=error_response.model_dump_json(by_alias=True, exclude_unset=True),
283-
status_code=HTTPStatus.NOT_FOUND,
284-
media_type="application/json",
308+
body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json"
285309
)
286310
await response(scope, receive, send)
287311

src/mcp/server/transport_security.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4242

4343
def _validate_host(self, host: str | None) -> bool:
4444
"""Validate the Host header against allowed values."""
45-
if not host: # pragma: no cover
45+
if not host:
4646
logger.warning("Missing Host header in request")
4747
return False
4848

4949
# Check exact match first
50-
if host in self.settings.allowed_hosts: # pragma: no cover
50+
if host in self.settings.allowed_hosts:
5151
return True
5252

5353
# Check wildcard port patterns
5454
for allowed in self.settings.allowed_hosts:
55-
if allowed.endswith(":*"): # pragma: no branch
55+
if allowed.endswith(":*"):
5656
# Extract base host from pattern
5757
base_host = allowed[:-2]
5858
# Check if the actual host starts with base host and has a port
@@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool:
6565
def _validate_origin(self, origin: str | None) -> bool:
6666
"""Validate the Origin header against allowed values."""
6767
# Origin can be absent for same-origin requests
68-
if not origin: # pragma: no cover
68+
if not origin:
6969
return True
7070

7171
# Check exact match first
72-
if origin in self.settings.allowed_origins: # pragma: no cover
72+
if origin in self.settings.allowed_origins:
7373
return True
7474

7575
# Check wildcard port patterns
7676
for allowed in self.settings.allowed_origins:
77-
if allowed.endswith(":*"): # pragma: no branch
77+
if allowed.endswith(":*"):
7878
# Extract base origin from pattern
7979
base_origin = allowed[:-2]
8080
# Check if the actual origin starts with base origin and has a port
@@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
9494
Returns None if validation passes, or an error Response if validation fails.
9595
"""
9696
# Always validate Content-Type for POST requests
97-
if is_post: # pragma: no branch
97+
if is_post:
9898
content_type = request.headers.get("content-type")
9999
if not self._validate_content_type(content_type):
100100
return Response("Invalid Content-Type header", status_code=400)

0 commit comments

Comments
 (0)