Skip to content

Commit 23b5cd5

Browse files
committed
Extract OTel context from _meta in incoming requests
This works for both incoming client->server requests and server->client requests like MCP Sampling.
1 parent 6e17e4c commit 23b5cd5

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

src/mcp/shared/session.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import anyio
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
12-
from opentelemetry.propagate import inject
12+
from opentelemetry import context as otel_context
13+
from opentelemetry.propagate import extract, inject
1314
from pydantic import BaseModel, TypeAdapter
1415
from typing_extensions import Self
1516

@@ -432,12 +433,28 @@ async def handle_message(message: SessionMessage) -> None:
432433
else: # Response or error
433434
await self._handle_response(message)
434435

436+
async def _handle_message_with_otel(message: SessionMessage) -> None:
437+
if isinstance(message.message, (JSONRPCRequest | JSONRPCNotification)) and message.message.params:
438+
meta: dict[str, str] = message.message.params.get("_meta", {})
439+
else:
440+
meta = {}
441+
442+
# Extract and then update the immutable context copy
443+
otel_token = otel_context.attach(extract(meta))
444+
message.context = contextvars.copy_context()
445+
446+
try:
447+
await handle_message(message)
448+
finally:
449+
if otel_token:
450+
otel_context.detach(otel_token)
451+
435452
async for message in self._read_stream:
436453
if isinstance(message, Exception): # pragma: no cover
437454
await self._handle_incoming(message)
438455
else:
439456
async with anyio.create_task_group() as tg:
440-
message.context.run(tg.start_soon, handle_message, message)
457+
message.context.run(tg.start_soon, _handle_message_with_otel, message)
441458

442459
except anyio.ClosedResourceError:
443460
# This is expected when the client disconnects abruptly.

tests/shared/test_otel_context_meta.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# ruff: noqa: E501
44

5+
import contextvars
56
from dataclasses import dataclass
67

78
import anyio
@@ -79,6 +80,11 @@ async def tool_with_sampling(topic: str, ctx: Context[ServerSession, None]) -> s
7980
)
8081
return "ran sampling"
8182

83+
@mcp.tool()
84+
async def tool_that_checks_trace_context() -> str:
85+
"""Returns current span details to verify parent propagation."""
86+
return trace.format_trace_id(trace.get_current_span().get_span_context().trace_id)
87+
8288
return mcp
8389

8490

@@ -106,8 +112,9 @@ async def patched_client(server: MCPServer, monkeypatch: pytest.MonkeyPatch):
106112
async def sampling_callback(
107113
context: RequestContext[ClientSession], params: types.CreateMessageRequestParams
108114
) -> types.CreateMessageResult:
115+
current_trace_id = trace.format_trace_id(trace.get_current_span().get_span_context().trace_id)
109116
return types.CreateMessageResult(
110-
role="assistant", content=TextContent(type="text", text="hello"), model="foomodel"
117+
role="assistant", content=TextContent(type="text", text=current_trace_id), model="foomodel"
111118
)
112119

113120
async with create_client_server_memory_streams() as (client_streams, server_streams):
@@ -119,7 +126,9 @@ def patch_stream_send(capture_to: list[JSONRPCMessage], stream: MemoryObjectSend
119126

120127
async def send_capture(item: SessionMessage) -> None:
121128
capture_to.append(item.message)
122-
await original_send(item)
129+
# Strip context to simulate transport boundary where context variables are not preserved
130+
new_item = SessionMessage(message=item.message, metadata=item.metadata, context=contextvars.Context())
131+
await original_send(new_item)
123132

124133
monkeypatch.setattr(stream, "send", send_capture)
125134

@@ -231,6 +240,17 @@ async def test_with_existing_meta(
231240
assert patched_client.client_to_server_messages == expect_client_to_server
232241

233242

243+
@pytest.mark.anyio
244+
async def test_trace_context_extraction(patched_client: PatchedClient):
245+
"""Test that OTEL context is successfully extracted on the receiving end."""
246+
247+
with trace.use_span(SPAN_IN_CLIENT):
248+
result = await patched_client.session.call_tool("tool_that_checks_trace_context")
249+
250+
# Verify that SPAN_IN_CLIENT was extracted and made it through to the handler
251+
assert result.content[0] == snapshot(TextContent(text="00000000000000000000000000000123"))
252+
253+
234254
@pytest.mark.anyio
235255
async def test_list_tools_with_span(patched_client: PatchedClient):
236256
"""Test that OTEL context is injected into the _meta field of a tools/list request."""
@@ -323,7 +343,11 @@ async def test_server_side_sampling_propagates_to_client(patched_client: Patched
323343
JSONRPCResponse(
324344
jsonrpc="2.0",
325345
id=0,
326-
result={"role": "assistant", "content": {"type": "text", "text": "hello"}, "model": "foomodel"},
346+
result={
347+
"role": "assistant",
348+
"content": {"type": "text", "text": "00000000000000000000000000000456"},
349+
"model": "foomodel",
350+
},
327351
),
328352
]
329353
)

0 commit comments

Comments
 (0)