@@ -52,6 +52,7 @@ async def handle_sse(request):
5252from starlette .types import Receive , Scope , Send
5353
5454import mcp .types as types
55+ from mcp .server .auth .middleware .bearer_auth import AuthenticatedUser , AuthorizationContext , authorization_context
5556from mcp .server .transport_security import (
5657 TransportSecurityMiddleware ,
5758 TransportSecuritySettings ,
@@ -75,6 +76,9 @@ class SseServerTransport:
7576
7677 _endpoint : str
7778 _read_stream_writers : dict [UUID , MemoryObjectSendStream [SessionMessage | Exception ]]
79+ # Identity of the credential that created each session; requests for a
80+ # session must present the same credential.
81+ _session_owners : dict [UUID , AuthorizationContext ]
7882 _security : TransportSecurityMiddleware
7983
8084 def __init__ (self , endpoint : str , security_settings : TransportSecuritySettings | None = None ) -> None :
@@ -115,6 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
115119
116120 self ._endpoint = endpoint
117121 self ._read_stream_writers = {}
122+ self ._session_owners = {}
118123 self ._security = TransportSecurityMiddleware (security_settings )
119124 logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
120125
@@ -142,6 +147,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
142147 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
143148
144149 session_id = uuid4 ()
150+ user = scope .get ("user" )
151+ if isinstance (user , AuthenticatedUser ):
152+ self ._session_owners [session_id ] = authorization_context (user )
145153 self ._read_stream_writers [session_id ] = read_stream_writer
146154 logger .debug (f"Created new session with ID: { session_id } " )
147155
@@ -177,26 +185,34 @@ async def sse_writer():
177185 }
178186 )
179187
180- async with anyio .create_task_group () as tg :
181-
182- async def response_wrapper (scope : Scope , receive : Receive , send : Send ):
183- """
184- The EventSourceResponse returning signals a client close / disconnect.
185- In this case we close our side of the streams to signal the client that
186- the connection has been closed.
187- """
188- await EventSourceResponse (content = sse_stream_reader , data_sender_callable = sse_writer )(
189- scope , receive , send
190- )
191- await read_stream_writer .aclose ()
192- await write_stream_reader .aclose ()
193- logging .debug (f"Client session disconnected { session_id } " )
194-
195- logger .debug ("Starting SSE response task" )
196- tg .start_soon (response_wrapper , scope , receive , send )
197-
198- logger .debug ("Yielding read and write streams" )
199- yield (read_stream , write_stream )
188+ try :
189+ async with anyio .create_task_group () as tg :
190+
191+ async def response_wrapper (scope : Scope , receive : Receive , send : Send ):
192+ """
193+ The EventSourceResponse returning signals a client close / disconnect.
194+ In this case we close our side of the streams to signal the client that
195+ the connection has been closed.
196+ """
197+ await EventSourceResponse (content = sse_stream_reader , data_sender_callable = sse_writer )(
198+ scope , receive , send
199+ )
200+ await read_stream_writer .aclose ()
201+ await write_stream_reader .aclose ()
202+ await sse_stream_reader .aclose ()
203+ logging .debug (f"Client session disconnected { session_id } " )
204+
205+ logger .debug ("Starting SSE response task" )
206+ tg .start_soon (response_wrapper , scope , receive , send )
207+
208+ logger .debug ("Yielding read and write streams" )
209+ yield (read_stream , write_stream )
210+ finally :
211+ # The connection is gone: stop routing messages to this session
212+ # and drop its entries so they do not accumulate for the lifetime
213+ # of the transport.
214+ self ._read_stream_writers .pop (session_id , None )
215+ self ._session_owners .pop (session_id , None )
200216
201217 async def handle_post_message (self , scope : Scope , receive : Receive , send : Send ) -> None : # pragma: no cover
202218 logger .debug ("Handling POST message" )
@@ -227,6 +243,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
227243 response = Response ("Could not find session" , status_code = 404 )
228244 return await response (scope , receive , send )
229245
246+ user = scope .get ("user" )
247+ requestor = authorization_context (user ) if isinstance (user , AuthenticatedUser ) else None
248+ if requestor != self ._session_owners .get (session_id ):
249+ # A session can only be used with the credential that created it.
250+ # Respond exactly as if the session did not exist.
251+ logger .warning ("Rejecting message for session %s: credential does not match" , session_id )
252+ response = Response ("Could not find session" , status_code = 404 )
253+ return await response (scope , receive , send )
254+
230255 body = await request .body ()
231256 logger .debug (f"Received JSON: { body } " )
232257
0 commit comments