22
33# ruff: noqa: E501
44
5+ import contextvars
56from dataclasses import dataclass
67
78import anyio
@@ -79,6 +80,11 @@ async def tool_with_sampling(topic: str, ctx: Context[ServerSession, None]) -> s
7980 )
8081 return "ran sampling"
8182
83+ @mcp .tool ()
84+ async def tool_that_checks_trace_context () -> str :
85+ """Returns current span details to verify parent propagation."""
86+ return trace .format_trace_id (trace .get_current_span ().get_span_context ().trace_id )
87+
8288 return mcp
8389
8490
@@ -106,8 +112,9 @@ async def patched_client(server: MCPServer, monkeypatch: pytest.MonkeyPatch):
106112 async def sampling_callback (
107113 context : RequestContext [ClientSession ], params : types .CreateMessageRequestParams
108114 ) -> types .CreateMessageResult :
115+ current_trace_id = trace .format_trace_id (trace .get_current_span ().get_span_context ().trace_id )
109116 return types .CreateMessageResult (
110- role = "assistant" , content = TextContent (type = "text" , text = "hello" ), model = "foomodel"
117+ role = "assistant" , content = TextContent (type = "text" , text = current_trace_id ), model = "foomodel"
111118 )
112119
113120 async with create_client_server_memory_streams () as (client_streams , server_streams ):
@@ -119,7 +126,9 @@ def patch_stream_send(capture_to: list[JSONRPCMessage], stream: MemoryObjectSend
119126
120127 async def send_capture (item : SessionMessage ) -> None :
121128 capture_to .append (item .message )
122- await original_send (item )
129+ # Strip context to simulate transport boundary where context variables are not preserved
130+ new_item = SessionMessage (message = item .message , metadata = item .metadata , context = contextvars .Context ())
131+ await original_send (new_item )
123132
124133 monkeypatch .setattr (stream , "send" , send_capture )
125134
@@ -231,6 +240,17 @@ async def test_with_existing_meta(
231240 assert patched_client .client_to_server_messages == expect_client_to_server
232241
233242
243+ @pytest .mark .anyio
244+ async def test_trace_context_extraction (patched_client : PatchedClient ):
245+ """Test that OTEL context is successfully extracted on the receiving end."""
246+
247+ with trace .use_span (SPAN_IN_CLIENT ):
248+ result = await patched_client .session .call_tool ("tool_that_checks_trace_context" )
249+
250+ # Verify that SPAN_IN_CLIENT was extracted and made it through to the handler
251+ assert result .content [0 ] == snapshot (TextContent (text = "00000000000000000000000000000123" ))
252+
253+
234254@pytest .mark .anyio
235255async def test_list_tools_with_span (patched_client : PatchedClient ):
236256 """Test that OTEL context is injected into the _meta field of a tools/list request."""
@@ -323,7 +343,11 @@ async def test_server_side_sampling_propagates_to_client(patched_client: Patched
323343 JSONRPCResponse (
324344 jsonrpc = "2.0" ,
325345 id = 0 ,
326- result = {"role" : "assistant" , "content" : {"type" : "text" , "text" : "hello" }, "model" : "foomodel" },
346+ result = {
347+ "role" : "assistant" ,
348+ "content" : {"type" : "text" , "text" : "00000000000000000000000000000456" },
349+ "model" : "foomodel" ,
350+ },
327351 ),
328352 ]
329353 )
0 commit comments