Commit b598b22

Drive task-augmented tool calls transparently in the client
A client that declares the tasks extension now handles CreateTaskResult everywhere SEP-2663 requires it:

- ClientSession widens its tools/call adapter per session when the extension is declared, and call_tool gains allow_create_task overloads mirroring allow_input_required (the guarded default raises with guidance instead of a pydantic ValidationError).
- Client.call_tool composes a polling driver after the input_required driver: a CreateTaskResult is polled via tasks/get (honoring pollIntervalMs with a 1s fallback) until terminal, the inlined result re-enters output-schema validation, and the public contract stays CallToolResult. failed, cancelled, and input_required tasks surface as typed TaskFailedError, TaskCancelledError, and TaskInputRequiredError.
- The tasks story's modern path becomes the plain typed call_tool (the same line demonstrates legacy degradation), with a compact manual leg over the mcp.shared.tasks wrappers.
1 parent 2ab9402 commit b598b22
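
The poll cadence described in the commit message can be sketched as a small pure function. This is a minimal illustration, not the SDK's code: `resolve_poll_delay` and its parameter names are hypothetical. Only the fallback order (the snapshot's hint, then the `CreateTaskResult`'s hint, then one second) is taken from this commit.

```python
from __future__ import annotations

DEFAULT_POLL_INTERVAL_SECONDS = 1.0  # the 1 s fallback the commit describes


def resolve_poll_delay(snapshot_ms: int | None, created_ms: int | None) -> float:
    """Return the delay in seconds before the next tasks/get poll (hypothetical helper)."""
    # Compare against None rather than truthiness so an explicit 0 is honored.
    interval_ms = snapshot_ms if snapshot_ms is not None else created_ms
    if interval_ms is None:
        return DEFAULT_POLL_INTERVAL_SECONDS
    return interval_ms / 1000


print(resolve_poll_delay(250, 1000))   # snapshot hint wins
print(resolve_poll_delay(None, 500))   # falls back to the CreateTaskResult hint
print(resolve_poll_delay(None, None))  # falls back to the default
```

Checking `is not None` instead of relying on truthiness matters here: a server that sends `pollIntervalMs: 0` would otherwise be treated as if it had sent no hint at all.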

10 files changed

Lines changed: 633 additions & 53 deletions

docs/migration.md

Lines changed: 12 additions & 4 deletions
@@ -470,10 +470,18 @@ Two reference extensions ship in their own modules:
 keeps completed tasks in a pluggable `TaskStore` (`Tasks(store=...)`,
 in-memory default) that enforces `default_ttl_ms`. A `tasks/*` call from a
 non-declaring modern client is rejected with `-32021` (missing required
-client capability); legacy calls get `METHOD_NOT_FOUND`. This is the core
-SEP-2663 surface; background execution (`working` tasks), the in-task
-`input_required` loop over `tasks/update`, `notifications/tasks`, and task
-routing headers are deferred.
+client capability); legacy calls get `METHOD_NOT_FOUND`. On the client side,
+a `Client` that declares the extension gets transparent polling:
+`Client.call_tool` recognises the `CreateTaskResult`, polls `tasks/get`
+(honoring `pollIntervalMs`), and returns the final `CallToolResult`
+unchanged, while `failed`/`cancelled` tasks surface as the typed
+`TaskFailedError`/`TaskCancelledError`. Manual driving stays available —
+`client.session.call_tool(..., allow_create_task=True)` returns the typed
+`CreateTaskResult`, and the `mcp.shared.tasks` request wrappers drive
+`tasks/get`/`tasks/update`/`tasks/cancel` over `session.send_request`. This
+is the core SEP-2663 surface; background execution (`working` tasks), the
+in-task `input_required` loop over `tasks/update`, `notifications/tasks`,
+and task routing headers are deferred.

 Extension methods are strictly additive: a `MethodBinding` cannot name a
 spec-defined request method, and registering one whose method collides with

examples/stories/tasks/README.md

Lines changed: 13 additions & 6 deletions
@@ -3,18 +3,20 @@
 Task-augmented execution (SEP-2663). A client declares the
 `io.modelcontextprotocol/tasks` extension; the server may then answer a
 `tools/call` with a `CreateTaskResult` (carrying a task id) instead of the
-`CallToolResult`, and the client fetches the result via `tasks/get`.
+`CallToolResult`. `Client.call_tool` drives the polling transparently and
+surfaces only the final result — the SEP's recommended client shape.

 ## Run it

 ```bash
 # stdio (default) — today's stdio negotiates the legacy wire, which cannot carry
 # the extension capability, so this leg demonstrates graceful degradation: the
-# same tools/call returns a plain CallToolResult, never a task.
+# same call_tool returns a plain CallToolResult, never a task.
 uv run python -m stories.tasks.client

 # HTTP — the modern wire negotiates the extension; the server defers the call as
-# a task and the client reads the result back via tasks/get
+# a task, Client.call_tool polls it to completion, and a manual leg shows the
+# raw CreateTaskResult -> tasks/get wire flow
 uv run python -m stories.tasks.client --http
 ```

@@ -29,9 +31,14 @@ uv run python -m stories.tasks.client --http
 declared the extension on the request, returning a flat `CreateTaskResult`
 (`resultType: "task"`).
 - `client.py` `Client(target, extensions={EXTENSION_ID: {}})` — declaring the
-extension is what lets the server defer; `main` then reads the `CreateTaskResult`
-and fetches `tasks/get`, whose completed envelope inlines the original
-`CallToolResult`.
+extension is what lets the server defer. The transparent path is then just
+`await client.call_tool(...)`: `Client` recognises the `CreateTaskResult`,
+polls `tasks/get` (honoring `pollIntervalMs`), and returns the final
+`CallToolResult`; a `failed` task raises `TaskFailedError`.
+- The manual leg — `session.call_tool(..., allow_create_task=True)` returns the
+typed `CreateTaskResult` (mirroring `allow_input_required`), and the shared
+`mcp.shared.tasks` wrappers (`GetTaskRequest`/`GetTaskResult`) drive `tasks/get`
+by hand over `session.send_request`.

 ## Scope

examples/stories/tasks/client.py

Lines changed: 37 additions & 38 deletions
@@ -1,59 +1,58 @@
-"""Declare the tasks extension, let the server defer a tool call, then fetch the result via tasks/get.
+"""Declare the tasks extension and let `Client.call_tool` drive the task transparently.

 The client declares `io.modelcontextprotocol/tasks` (via `Client(extensions=...)`),
-so the server is free to answer `tools/call` with a `CreateTaskResult`. `Client`
-exposes only spec verbs, so the augmented call and `tasks/get` drop to
-`client.session`; the thin `_send` helper keeps that out of the story below.
+so the server is free to answer `tools/call` with a `CreateTaskResult`. SEP-2663
+advises clients to keep a fixed public contract and drive the polling internally —
+`Client.call_tool` does exactly that, so the modern path is the same typed call a
+task-less server would get. A compact manual leg then shows the raw wire flow:
+`session.call_tool(allow_create_task=True)` for the typed `CreateTaskResult`, and
+the shared `mcp.shared.tasks` wrappers over `session.send_request` for `tasks/get`.
 """

-from typing import Any, Literal, cast
+from typing import cast

 import mcp_types as types
-from pydantic import TypeAdapter

-from mcp.client import Client, ClientSession
-from mcp.server.tasks import EXTENSION_ID, GetTaskRequestParams
+from mcp.client import Client
+from mcp.server.tasks import EXTENSION_ID
+from mcp.shared.tasks import CreateTaskResult, GetTaskRequest, GetTaskRequestParams, GetTaskResult
 from stories._harness import Target, run_client

-_RAW: TypeAdapter[dict[str, Any]] = TypeAdapter(dict)
-
-
-class _GetTaskRequest(types.Request[GetTaskRequestParams, Literal["tasks/get"]]):
-    method: Literal["tasks/get"] = "tasks/get"
-    params: GetTaskRequestParams
-
-
-async def _send(session: ClientSession, request: types.Request[Any, Any]) -> dict[str, Any]:
-    """Send a request whose result has a non-spec (extension) shape; return the raw dict."""
-    return await session.send_request(cast("types.ClientRequest", request), cast("Any", _RAW))
-

 async def main(target: Target, *, mode: str = "auto") -> None:
     async with Client(target, mode=mode, extensions={EXTENSION_ID: {}}) as client:
-        # The extension is a modern-only capability negotiated over server/discover.
-        # A legacy connection (today's stdio) cannot carry it, and the server then
-        # must not augment: the same tools/call degrades to a plain CallToolResult.
+        # The transparent path. On the modern wire the server augments this
+        # tools/call into a task (we declared the extension) and Client.call_tool
+        # polls tasks/get to the final result; on a legacy connection (today's
+        # stdio) the extension cannot be negotiated, the server must not augment,
+        # and the very same call simply returns the plain CallToolResult.
+        result = await client.call_tool("render_report", {"title": "Q3", "sections": 2})
+        assert isinstance(result.content[0], types.TextContent), result
+        assert result.content[0].text.startswith("# Q3"), result
+        # No 2025-style related-task _meta either; the task plumbing never leaks
+        # into the surfaced result.
+        assert result.meta is None, result
+
         if client.server_capabilities.extensions is None:
-            result = await client.call_tool("render_report", {"title": "Q3", "sections": 2})
-            assert isinstance(result.content[0], types.TextContent), result
-            assert result.content[0].text.startswith("# Q3"), result
-            # No 2025-style related-task _meta either; SEP-2663 augmentation would
-            # have replaced the whole result, failing CallToolResult parsing above.
-            assert result.meta is None, result
+            # Legacy wire: nothing more to show — the degradation above is the point.
             return
         assert client.server_capabilities.extensions == {EXTENSION_ID: {}}

-        # The server augments this tools/call into a task because we declared the extension.
-        call = types.CallToolRequest(
-            params=types.CallToolRequestParams(name="render_report", arguments={"title": "Q3", "sections": 2})
+        # The manual leg: the same flow driven by hand on the raw wire.
+        # allow_create_task=True hands back the typed CreateTaskResult instead of
+        # polling, and the shared SEP-2663 request wrappers fetch the outcome.
+        created = await client.session.call_tool(
+            "render_report", {"title": "Q3", "sections": 1}, allow_create_task=True
         )
-        created = await _send(client.session, call)
-        assert created["resultType"] == "task", created
-        task_id = created["taskId"]
+        assert isinstance(created, CreateTaskResult), created

-        task = await _send(client.session, _GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)))
-        assert task["status"] == "completed", task
-        assert task["result"]["content"][0]["text"].startswith("# Q3"), task
+        task = await client.session.send_request(
+            cast("types.ClientRequest", GetTaskRequest(params=GetTaskRequestParams(task_id=created.task_id))),
+            GetTaskResult,
+        )
+        assert task.status == "completed", task
+        assert task.result is not None, task
+        assert task.result["content"][0]["text"].startswith("# Q3"), task


 if __name__ == "__main__":
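
The transparent path polls until the task is terminal and, per the driver's docstring in this commit, sets no deadline of its own, so a caller may want to bound the wait. The sketch below shows that pattern with plain `asyncio`. `slow_call_tool` is a hypothetical stand-in for `client.call_tool`, and `asyncio.wait_for` is swapped in for the anyio cancel scope the docstring mentions, so treat it as an illustration of the idea rather than SDK usage.

```python
from __future__ import annotations

import asyncio


async def slow_call_tool(name: str) -> str:
    """Hypothetical stand-in for a call whose task never reaches a terminal status."""
    while True:
        await asyncio.sleep(0.01)  # pretend each iteration is one tasks/get poll


async def call_with_deadline(name: str, timeout_s: float) -> str | None:
    """Return the result, or None if the caller's deadline expires first."""
    try:
        return await asyncio.wait_for(slow_call_tool(name), timeout=timeout_s)
    except asyncio.TimeoutError:
        # wait_for cancels the inner coroutine, which ends the local polling loop.
        return None


outcome = asyncio.run(call_with_deadline("render_report", timeout_s=0.05))
print(outcome)
```

The sketch only stops local polling when the deadline passes; it does not send `tasks/cancel`, so whatever the server is doing is left untouched.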

src/mcp/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,7 @@
 from mcp_types import Role as SamplingRole

 from .client._input_required import InputRequiredRoundsExceededError
+from .client._tasks import TaskCancelledError, TaskFailedError, TaskInputRequiredError
 from .client.client import Client
 from .client.session import ClientSession
 from .client.session_group import ClientSessionGroup
@@ -128,6 +129,9 @@
     "StdioServerParameters",
     "StopReason",
     "SubscribeRequest",
+    "TaskCancelledError",
+    "TaskFailedError",
+    "TaskInputRequiredError",
     "Tool",
     "ToolChoice",
     "ToolResultContent",

src/mcp/client/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 """MCP Client module."""

 from mcp.client._input_required import InputRequiredRoundsExceededError
+from mcp.client._tasks import TaskCancelledError, TaskFailedError, TaskInputRequiredError
 from mcp.client._transport import Transport
 from mcp.client.caching import (
     CacheConfig,
@@ -25,5 +26,8 @@
     "InMemoryResponseCacheStore",
     "InputRequiredRoundsExceededError",
     "ResponseCacheStore",
+    "TaskCancelledError",
+    "TaskFailedError",
+    "TaskInputRequiredError",
     "Transport",
 ]

src/mcp/client/_tasks.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+"""SEP-2663 client-side task polling driver.
+
+When a server augments a `tools/call` into a task — a `CreateTaskResult` in
+place of the `CallToolResult` — the client polls `tasks/get` until the task
+reaches a terminal status and surfaces only the final result. SEP-2663 advises
+exactly this shape: "existing code returning a fixed shape ... can transparently
+drive the polling flow internally and surface only the final, completed result".
+This module implements that loop as a pure function so it stays testable with
+plain closures; `Client` builds the `get_task` closure over its session,
+`ClientSession` stays mechanics-only (mirroring `_input_required`).
+"""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+
+import anyio
+from mcp_types import CallToolResult, ErrorData
+
+from mcp.shared.exceptions import MCPError
+from mcp.shared.tasks import CreateTaskResult, GetTaskResult
+
+DEFAULT_POLL_INTERVAL_SECONDS = 1.0
+"""Poll cadence when neither the snapshot nor the `CreateTaskResult` carries `pollIntervalMs`.
+
+SEP-2663 makes the hint optional and only says clients SHOULD honor it when
+present; one second is the SDK's conservative default in its absence.
+"""
+
+
+class TaskFailedError(MCPError):
+    """The task reached `failed`: a JSON-RPC error occurred during execution (SEP-2663).
+
+    Carries the JSON-RPC error inlined on `tasks/get` as `code`/`message`/`data`,
+    plus the snapshot's optional `statusMessage` diagnostic.
+    """
+
+    def __init__(self, error: ErrorData, status_message: str | None = None) -> None:
+        super().__init__(code=error.code, message=error.message, data=error.data)
+        self.status_message = status_message
+
+
+class TaskCancelledError(RuntimeError):
+    """The task reached `cancelled` before producing a result (SEP-2663)."""
+
+    def __init__(self, task_id: str, status_message: str | None = None) -> None:
+        detail = f": {status_message}" if status_message is not None else ""
+        super().__init__(f"Task {task_id!r} was cancelled{detail}")
+        self.task_id = task_id
+        self.status_message = status_message
+
+
+class TaskInputRequiredError(RuntimeError):
+    """The task reached `input_required`, which this driver does not drive yet.
+
+    SEP-2663's in-task input loop (fulfil `inputRequests` via `tasks/update`) is
+    a deferred follow-up in this SDK. Drive it manually: poll with
+    `mcp.shared.tasks.GetTaskRequest` and answer with
+    `mcp.shared.tasks.UpdateTaskRequest` over `session.send_request`.
+    """
+
+    def __init__(self, task_id: str) -> None:
+        super().__init__(
+            f"Task {task_id!r} requires in-task input (status `input_required`); the SDK's automatic "
+            "in-task input loop is not implemented yet. Drive it manually with the `mcp.shared.tasks` "
+            "request wrappers (`GetTaskRequest`/`UpdateTaskRequest`) over `session.send_request`."
+        )
+        self.task_id = task_id
+
+
+async def run_task_driver(
+    created: CreateTaskResult,
+    *,
+    get_task: Callable[[str], Awaitable[GetTaskResult]],
+    sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
+) -> CallToolResult:
+    """Poll a `CreateTaskResult` to its final `CallToolResult`.
+
+    Polls `tasks/get` (via `get_task`) until the task reaches a terminal status.
+    Between polls it honors the SEP-2663 `pollIntervalMs` hint: each non-terminal
+    snapshot sleeps its own `poll_interval_ms`, falling back to the
+    `CreateTaskResult`'s, then to `DEFAULT_POLL_INTERVAL_SECONDS`.
+
+    The loop deliberately imposes no round cap or deadline of its own: SEP-2663
+    tasks represent unbounded server-side work, so how long to wait is the
+    caller's policy — cancel via an enclosing anyio cancel scope, or bound each
+    `tasks/get` round with the session read timeout the `get_task` closure
+    carries.
+
+    Args:
+        created: The `CreateTaskResult` the augmented request returned.
+        get_task: Sends one `tasks/get` for the given task id and returns the
+            parsed `GetTaskResult` snapshot.
+        sleep: Awaits the given number of seconds between polls (injectable for
+            deterministic tests).
+
+    Raises:
+        TaskFailedError: The task reached `failed`; carries the inlined JSON-RPC error.
+        TaskCancelledError: The task reached `cancelled`.
+        TaskInputRequiredError: The task reached `input_required` (the in-task
+            input loop is not implemented yet).
+        RuntimeError: The server violated SEP-2663 — a `completed` snapshot
+            without `result`, or a `failed` snapshot without `error`.
+    """
+    while True:
+        snapshot = await get_task(created.task_id)
+        if snapshot.status == "completed":
+            if snapshot.result is None:
+                raise RuntimeError(
+                    f"Task {created.task_id!r} is `completed` but carries no `result` (SEP-2663 violation)"
+                )
+            return CallToolResult.model_validate(snapshot.result, by_name=False)
+        if snapshot.status == "failed":
+            if snapshot.error is None:
+                raise RuntimeError(f"Task {created.task_id!r} is `failed` but carries no `error` (SEP-2663 violation)")
+            raise TaskFailedError(ErrorData.model_validate(snapshot.error), snapshot.status_message)
+        if snapshot.status == "cancelled":
+            raise TaskCancelledError(created.task_id, snapshot.status_message)
+        if snapshot.status == "input_required":
+            raise TaskInputRequiredError(created.task_id)
+        interval_ms = snapshot.poll_interval_ms if snapshot.poll_interval_ms is not None else created.poll_interval_ms
+        await sleep(DEFAULT_POLL_INTERVAL_SECONDS if interval_ms is None else interval_ms / 1000)