diff --git a/docs/community/troubleshooting/index.md b/docs/community/troubleshooting/index.md index d3f5e07fe..680a0c92d 100644 --- a/docs/community/troubleshooting/index.md +++ b/docs/community/troubleshooting/index.md @@ -127,6 +127,8 @@ MCP and LSP toolsets are managed by a supervisor that auto-restarts them when th - `/tools` — the unified tools dialog. Its top section lists every toolset with its current state (`Stopped`, `Starting`, `Ready`, `Degraded`, `Restarting`, `Failed`), restart count, and last error; the bottom section lists every tool the agent can call. Start here whenever a tool seems missing or stuck. - `/toolset-restart ` — force a supervisor-driven reconnect of the named toolset. Useful after completing OAuth, when a remote MCP server has been redeployed, or when a language server like `gopls` is unresponsive. +Remote MCP servers that return `401 invalid_token` (e.g. because the stored OAuth token was revoked or rotated) are now self-healing: docker-agent silently exchanges the refresh token for a new one when possible, or surfaces an OAuth re-authentication prompt on your next message when refresh is not possible. No more stuck toolsets that require a process restart — but if you want to trigger re-auth immediately, `/toolset-restart ` forces it right away. + MCP tools using stdio transport must complete the initialization handshake before becoming available. If tools fail silently: 1. Run `/tools` to see whether the toolset is `Failed` or stuck in `Restarting`, and what the last error was. diff --git a/docs/features/remote-mcp/index.md b/docs/features/remote-mcp/index.md index 999b7ec7e..2967eba53 100644 --- a/docs/features/remote-mcp/index.md +++ b/docs/features/remote-mcp/index.md @@ -64,6 +64,17 @@ Set `allow_private_ips: true` on a remote MCP toolset only when the MCP server o

Remote MCP connections (Streamable HTTP / SSE) automatically reconnect after the server closes an idle connection — no configuration needed. Services like Notion and Linear close idle connections periodically; docker-agent detects the clean close and reconnects with exponential backoff. To tune reconnect behaviour or disable reconnection entirely, use the lifecycle block.

+
+
Automatic recovery from revoked or rotated OAuth tokens
+

If a remote MCP server rejects the cached token with a 401 invalid_token error (for example, because the token was revoked or rotated server-side), docker-agent handles the failure automatically:

+
    +
  • Silent refresh: when a refresh token is available, docker-agent exchanges it for a new access token and replays the request, with no user interaction required.
  • Re-authentication prompt: when no refresh token is available or the refresh attempt fails, the toolset transitions to a "needs re-auth" state and surfaces an OAuth prompt on your next message (exactly like the first-time flow).
+

Either way, the agent never spends 5 reconnect attempts on an auth failure: it fails fast, then either refreshes silently or defers to interactive re-auth. To trigger re-auth immediately instead of waiting for the next message, run /toolset-restart <name> from the TUI.

+
+ ### OAuth for servers without Dynamic Client Registration Most remote MCP servers that require OAuth support [Dynamic Client Registration (RFC 7591)]({{ 'https://datatracker.ietf.org/doc/html/rfc7591' }}) — no configuration is needed, docker-agent handles the flow for you. diff --git a/docs/tools/mcp/index.md b/docs/tools/mcp/index.md index 50a1fc5fc..1476a56ac 100644 --- a/docs/tools/mcp/index.md +++ b/docs/tools/mcp/index.md @@ -79,7 +79,7 @@ toolsets: ## Remote MCP (Streamable HTTP / SSE) -Connect to MCP servers over the network. OAuth flows (including [Dynamic Client Registration](https://datatracker.ietf.org/doc/html/rfc7591)) are handled automatically — docker-agent opens your browser when authentication is required and caches tokens for subsequent sessions. +Connect to MCP servers over the network. OAuth flows (including [Dynamic Client Registration](https://datatracker.ietf.org/doc/html/rfc7591)) are handled automatically — docker-agent opens your browser when authentication is required and caches tokens for subsequent sessions. Tokens are refreshed silently when they expire or are revoked server-side; if a silent refresh is not possible, the OAuth prompt reappears on the next message. ```yaml toolsets: diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 802024355..1d6751950 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/tools" + mcptools "github.com/docker/docker-agent/pkg/tools/mcp" ) // Agent represents an AI agent @@ -409,6 +410,18 @@ func (a *Agent) ensureToolSetsAreStarted(ctx context.Context) { continue } desc := tools.DescribeToolSet(toolSet) + if mcptools.IsAuthorizationRequired(err) { + // Recovery: previously-working toolset lost its OAuth token in the + // background. 
Emit the targeted re-auth notice once per streak so the + // user knows a dialog will appear on their next message. + // Initial-startup auth deferral (ShouldReportRecoveryFailure==false) + // stays silent — the dialog appears naturally on the first turn. + if toolSet.ShouldReportRecoveryFailure() { + slog.WarnContext(ctx, "Toolset needs re-authentication after background token rejection", "agent", a.Name(), "toolset", desc) + a.AddToolWarning(desc + " needs re-authentication — it will prompt on your next message, or use /toolset-restart") + } + continue + } if toolSet.ShouldReportFailure() { slog.WarnContext(ctx, "Toolset start failed; will retry on next turn", "agent", a.Name(), "toolset", desc, "error", err) a.AddToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index ce5c38b7e..0a7c5634e 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -1243,41 +1243,53 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen isLast := i == totalToolsets-1 - // Start the toolset if needed + // Start the toolset if needed, including recovery: a previously-started + // toolset whose inner connection died (e.g. background invalid_token) + // must have its recovery Start() called here so ShouldReportRecoveryFailure + // can fire the targeted re-auth notice. Start() is a no-op when the + // toolset is already healthy, so calling it unconditionally is safe. if startable, ok := toolset.(*tools.StartableToolSet); ok { - if !startable.IsStarted() { - if err := startable.Start(ctx); err != nil { - desc := tools.DescribeToolSet(startable.ToolSet) - // IsAuthorizationRequired must be checked BEFORE - // ShouldReportFailure: this is the first — expected — - // failure of a deferred-OAuth toolset, and consuming the - // failure-reported flag here would suppress the *real* - // failure (e.g. server 4xx on the eventual interactive - // retry) that the user actually needs to see. 
- if mcptools.IsAuthorizationRequired(err) { - // The toolset just needs an OAuth approval that we - // deliberately deferred until the user is interacting - // with the agent. The dialog will appear naturally on - // the first RunStream — no need to pre-announce it. + if err := startable.Start(ctx); err != nil { + desc := tools.DescribeToolSet(startable.ToolSet) + // IsAuthorizationRequired must be checked BEFORE + // ShouldReportFailure: this is the first — expected — + // failure of a deferred-OAuth toolset, and consuming the + // failure-reported flag here would suppress the *real* + // failure (e.g. server 4xx on the eventual interactive + // retry) that the user actually needs to see. + if mcptools.IsAuthorizationRequired(err) { + // Two cases: + // 1. Initial startup deferral (toolset never ran): the + // OAuth dialog will appear naturally on the first user + // message — no need to pre-announce it. + // 2. Recovery: the toolset was previously working but the + // background watcher detected a server-side invalid_token + // (fixes #3198). Surface a deduped re-auth notice so the + // user knows what is about to prompt on their next message. + if startable.ShouldReportRecoveryFailure() { + slog.WarnContext(ctx, "Toolset needs re-authentication after background token rejection", + "agent", a.Name(), "toolset", desc) + a.AddToolWarning(desc + " needs re-authentication — it will prompt on your next message, or use /toolset-restart") + } else { slog.DebugContext(ctx, "Toolset deferred until first message", "agent", a.Name(), "toolset", desc, "reason", err) - continue } - // Route real failures through the agent's warning - // channel so the TUI surfaces a persistent, - // user-visible notice that includes the actual - // server-side cause (threaded through by - // remoteMCPClient.Initialize). Use the same - // once-per-streak guard as ensureToolSetsAreStarted - // so a failing toolset doesn't flood the UI with a - // new warning every time the agent is restarted. 
- if !startable.ShouldReportFailure() { - slog.DebugContext(ctx, "Toolset still unavailable; skipping", "agent", a.Name(), "toolset", desc, "error", err) - continue - } - slog.WarnContext(ctx, "Toolset start failed; skipping", "agent", a.Name(), "toolset", desc, "error", err) - a.AddToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) continue } + // Route real failures through the agent's warning + // channel so the TUI surfaces a persistent, + // user-visible notice that includes the actual + // server-side cause (threaded through by + // remoteMCPClient.Initialize). Use the same + // once-per-streak guard as ensureToolSetsAreStarted + // so a failing toolset doesn't flood the UI with a + // new warning every time the agent is restarted. + if !startable.ShouldReportFailure() { + slog.DebugContext(ctx, "Toolset still unavailable; skipping", "agent", a.Name(), "toolset", desc, "error", err) + continue + } + slog.WarnContext(ctx, "Toolset start failed; skipping", "agent", a.Name(), "toolset", desc, "error", err) + a.AddToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) + continue } } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index bda6060c2..7ea4e73bc 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -1233,6 +1233,99 @@ func TestEmitStartupInfo_DeferredAuthDoesNotConsumeFailureGate(t *testing.T) { "and the user sees zero tools with no explanation") } +// recoveryAuthToolSet simulates a toolset whose first Start() always succeeds, +// and whose Restart() returns a configurable error (used to simulate a +// background invalid_token loss after a prior successful start). +// IsStarted() reflects live connection state so StartableToolSet.Start() can +// detect the "inner went dead" recovery scenario. 
+type recoveryAuthToolSet struct { + started bool + restartErr error +} + +func (r *recoveryAuthToolSet) Tools(context.Context) ([]tools.Tool, error) { return nil, nil } +func (r *recoveryAuthToolSet) Start(context.Context) error { r.started = true; return nil } +func (r *recoveryAuthToolSet) Stop(context.Context) error { r.started = false; return nil } +func (r *recoveryAuthToolSet) IsStarted() bool { return r.started } +func (r *recoveryAuthToolSet) Restart(context.Context) error { return r.restartErr } + +// TestEmitStartupInfo_RecoveryAuthNoticeEmittedOnce is the regression test for +// blocking issue 3: when a toolset was previously started and working but the +// background watcher detected a server-side invalid_token, the next call to +// emitToolsProgressively must attempt a recovery Start() and emit exactly one +// targeted re-auth notice. Initial-startup auth deferral (toolset never worked +// before) must remain silent. The streak resets on success so a subsequent +// background failure produces a fresh notice. +func TestEmitStartupInfo_RecoveryAuthNoticeEmittedOnce(t *testing.T) { + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + authErr := &mcptools.AuthorizationRequiredError{URL: "https://example.test/mcp"} + + inner := &recoveryAuthToolSet{restartErr: authErr} + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(inner), + ) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + var wrapped *tools.StartableToolSet + for _, ts := range root.ToolSets() { + if s, ok := ts.(*tools.StartableToolSet); ok { + wrapped = s + break + } + } + require.NotNil(t, wrapped, "agent.ToolSets() must wrap the inner toolset in a *tools.StartableToolSet") + + // nopSend discards sidebar events; we inspect agent.DrainWarnings() instead. 
+ nopSend := func(Event) bool { return true } + // Mirror EmitStartupInfo's non-interactive context so toolsets with OAuth + // fail fast rather than blocking on a prompt. + ctx := mcptools.WithoutInteractivePrompts(t.Context()) + + // Phase 1 (initial startup): inner.Start() succeeds (first call); no recovery + // notice because the toolset was never previously working. + rt.emitToolsProgressively(ctx, root, nopSend) + _ = root.DrainWarnings() // clear any unrelated warnings + require.True(t, wrapped.IsStarted(), "toolset must be started after initial success") + + // Phase 2 (background failure): inner loses its connection (e.g. server-side + // invalid_token eviction set the live started flag to false). + inner.started = false + + // First emitToolsProgressively after the background failure: recovery Start() + // is attempted (Restart returns authErr), and exactly one targeted notice is + // added to the agent's warning queue. + rt.emitToolsProgressively(ctx, root, nopSend) + noticesPhase2 := root.DrainWarnings() + require.Len(t, noticesPhase2, 1, + "exactly one targeted re-auth notice must be emitted on the first recovery failure") + assert.Contains(t, noticesPhase2[0], "needs re-authentication", + "recovery notice must use the targeted re-auth framing, not the generic start-failed message") + + // Dedup: ShouldReportRecoveryFailure was consumed by emitToolsProgressively; + // a direct call must return false (streak is still active but pending cleared). + assert.False(t, wrapped.ShouldReportRecoveryFailure(), + "ShouldReportRecoveryFailure must return false after the first notice was emitted (dedup)") + + // Phase 3 (inner recovers): a successful Start() (via inner.Start() since + // wrapped.started==false after failed Restart) resets the recovery streak. 
+ inner.started = true + rt.emitToolsProgressively(ctx, root, nopSend) + _ = root.DrainWarnings() + require.True(t, wrapped.IsStarted(), "toolset must be re-started after recovery") + assert.False(t, wrapped.ShouldReportRecoveryFailure(), + "recovery streak must be reset after a successful Start") + + // Phase 4: background failure again — streak was reset, so a fresh notice + // is expected (verifies reset-on-success behavior). + inner.started = false + rt.emitToolsProgressively(ctx, root, nopSend) + noticesPhase4 := root.DrainWarnings() + require.Len(t, noticesPhase4, 1, "fresh failure after streak reset must emit a new notice") +} + // TestEmitAgentWarnings_OnlyEmitsFailures verifies that emitAgentWarnings // only surfaces real failures to the user. Recovery is intentionally // silent: a previously-failed toolset becoming available again does NOT diff --git a/pkg/tools/lifecycle/classify.go b/pkg/tools/lifecycle/classify.go index c72d0375e..8e37eb95d 100644 --- a/pkg/tools/lifecycle/classify.go +++ b/pkg/tools/lifecycle/classify.go @@ -76,6 +76,15 @@ func classifyByMessage(err error) error { strings.Contains(lower, "broken pipe"), strings.Contains(msg, "EOF"): return wrap(ErrTransport, err) + // Map server-side OAuth token rejection to ErrAuthRequired. We match + // "invalid_token" (RFC 6750 §3.1 canonical error code) and its space- + // separated variant. We deliberately do NOT match bare "unauthorized" + // here to avoid classifying application-level 401s (unrelated to OAuth) + // as permanent auth failures; the token-was-attached gating in + // oauthTransport.roundTrip is the correct place for that check. 
+ case strings.Contains(lower, "invalid_token"), + strings.Contains(lower, "invalid token"): + return wrap(ErrAuthRequired, err) } return err } diff --git a/pkg/tools/lifecycle/classify_test.go b/pkg/tools/lifecycle/classify_test.go index e7ebbfd21..059d52de5 100644 --- a/pkg/tools/lifecycle/classify_test.go +++ b/pkg/tools/lifecycle/classify_test.go @@ -81,6 +81,43 @@ func TestClassify_AlreadyClassifiedPasses(t *testing.T) { assert.Check(t, errors.Is(got, lifecycle.ErrAuthRequired)) } +func TestClassify_InvalidToken(t *testing.T) { + t.Parallel() + cases := []struct { + name string + msg string + }{ + {"rfc6750_error_code", `401 Unauthorized: {"error":"invalid_token","error_description":"Invalid access token"}`}, + {"space_variant", "server rejected token: invalid token"}, + {"upper_case", "INVALID_TOKEN: token expired"}, + } + for _, tc := range cases { + got := lifecycle.Classify(errors.New(tc.msg)) + assert.Check(t, errors.Is(got, lifecycle.ErrAuthRequired), "msg=%q", tc.msg) + assert.Check(t, lifecycle.IsPermanent(got), "msg=%q: must be permanent", tc.msg) + } +} + +func TestClassify_BareUnauthorizedIsNotAuth(t *testing.T) { + t.Parallel() + // A bare "unauthorized" without "invalid_token" must NOT be classified as + // ErrAuthRequired to avoid misreading application-level 401s as permanent + // auth failures (see human decision Q3 in the implementation plan). + got := lifecycle.Classify(errors.New("401 Unauthorized")) + assert.Check(t, !errors.Is(got, lifecycle.ErrAuthRequired), "bare unauthorized must not map to ErrAuthRequired") +} + +func TestClassify_InvalidToken_Idempotent(t *testing.T) { + t.Parallel() + // Classify must be idempotent: an already-wrapped ErrAuthRequired that + // also contains "invalid_token" in its message must not be double-wrapped. 
+ inner := errors.New("invalid_token: expired") + first := lifecycle.Classify(inner) + second := lifecycle.Classify(first) + assert.Check(t, errors.Is(second, lifecycle.ErrAuthRequired)) + assert.Check(t, errors.Is(second, inner)) +} + func TestClassify_UnknownPassthrough(t *testing.T) { t.Parallel() in := errors.New("totally unrelated") diff --git a/pkg/tools/lifecycle/supervisor.go b/pkg/tools/lifecycle/supervisor.go index 5ddc48eea..55a4ec9aa 100644 --- a/pkg/tools/lifecycle/supervisor.go +++ b/pkg/tools/lifecycle/supervisor.go @@ -10,6 +10,31 @@ import ( "time" ) +// backgroundReconnectKey is a context key that the supervisor attaches to +// connector.Connect calls made during background watcher reconnect attempts, +// distinguishing them from the initial interactive Start. Connector +// implementations (e.g. the MCP clientConnector) use this to apply +// non-interactive constraints on background reconnects so a 401 defers +// cleanly rather than blocking on a dead elicitation bridge. +type backgroundReconnectKey struct{} + +// withBackgroundReconnect returns a copy of ctx marked as a background +// reconnect attempt. It is set by tryRestart before calling +// connector.Connect so the connector can distinguish watcher reconnects +// from the initial interactive Start. +func withBackgroundReconnect(ctx context.Context) context.Context { + return context.WithValue(ctx, backgroundReconnectKey{}, true) +} + +// IsBackgroundReconnect reports whether ctx was created by the supervisor +// for a background reconnect attempt. Connector.Connect implementations can +// use this to disable interactive operations (e.g. OAuth prompts) that +// should not run in the background. +func IsBackgroundReconnect(ctx context.Context) bool { + v, _ := ctx.Value(backgroundReconnectKey{}).(bool) + return v +} + // Connector creates new sessions for a Supervisor. Implementations are // transport-specific: stdio MCP, remote MCP, LSP stdio. 
type Connector interface { @@ -462,8 +487,21 @@ func (s *Supervisor) tryRestart(ctx context.Context) bool { } s.mu.Unlock() - sess, err := s.connector.Connect(ctx) + sess, err := s.connector.Connect(withBackgroundReconnect(ctx)) if err != nil { + // A permanent error on reconnect (e.g. ErrAuthRequired from a + // server-side invalid_token) must not be retried: doing so would + // burn through the budget and mask the real failure. Symmetric + // with the shouldRestart check on the Wait() path. + if IsPermanent(err) { + log.Warn("supervisor: permanent error on reconnect; not retrying", "name", s.name, "error", err) + s.tracker.Fail(StateFailed, err) + if cb := s.policy.OnFailed; cb != nil { + cb(err) + } + s.signalDone() + return false + } s.tracker.Fail(StateRestarting, err) s.tracker.IncRestarts() log.Warn("supervisor: restart failed", "name", s.name, "attempt", attempt+1, "error", err) diff --git a/pkg/tools/lifecycle/supervisor_test.go b/pkg/tools/lifecycle/supervisor_test.go index 407505993..5b7f4e8d6 100644 --- a/pkg/tools/lifecycle/supervisor_test.go +++ b/pkg/tools/lifecycle/supervisor_test.go @@ -451,6 +451,54 @@ func TestSupervisor_PermanentErrorsDontRestart(t *testing.T) { assert.Check(t, is.Equal(c.Calls(), 1), "must not retry on permanent error") } +// TestSupervisor_PermanentConnectErrorDoesNotRetry verifies that when +// connector.Connect returns a permanent error during a background reconnect +// attempt (e.g. ErrAuthRequired from a server-side invalid_token), the +// supervisor transitions to StateFailed immediately without burning through +// its MaxAttempts budget. +// +// This is the gap the bug exercised: the session Wait succeeded (server +// closed cleanly) but the subsequent reconnect Connect returned a permanent +// error that the old supervisor would retry N times before giving up. 
+func TestSupervisor_PermanentConnectErrorDoesNotRetry(t *testing.T) { + t.Parallel() + + // sess1 is the initial successful connection; then the reconnect + // returns a permanent auth error. + sess1 := newFakeSession() + c := newScriptedConnector( + scriptStep{session: sess1}, + scriptStep{err: lifecycle.ErrAuthRequired}, // permanent: must NOT burn MaxAttempts + ) + + failedCh := make(chan error, 1) + s := lifecycle.New("test", c, lifecycle.Policy{ + MaxAttempts: 5, // budget that must NOT be consumed + Backoff: fastBackoff, + OnFailed: func(err error) { + select { + case failedCh <- err: + default: + } + }, + }) + + assert.NilError(t, s.Start(t.Context())) + // Make the session fail non-permanently so tryRestart is entered. + sess1.fail(errors.New("transport closed")) + + select { + case got := <-failedCh: + assert.Check(t, errors.Is(got, lifecycle.ErrAuthRequired)) + case <-time.After(2 * time.Second): + t.Fatal("supervisor did not call OnFailed") + } + + assert.Check(t, is.Equal(s.State().State, lifecycle.StateFailed)) + // One initial Connect + one reconnect attempt that returned permanent error. + assert.Check(t, is.Equal(c.Calls(), 2), "must fail-fast after one reconnect attempt on permanent error") +} + func TestSupervisor_CleanClosePolicyBoundary(t *testing.T) { t.Parallel() diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index 26e2bf4ef..d829a9d10 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -482,6 +482,16 @@ type clientConnector struct { func (c *clientConnector) Connect(ctx context.Context) (lifecycle.Session, error) { ts := c.ts + // Background reconnects (watcher goroutine retrying after a disconnect) + // must not block on interactive flows such as OAuth prompts, because the + // elicitation bridge is not reliably available outside of a user turn. 
+ // Marking the context non-interactive here ensures a background 401 + // returns AuthorizationRequiredError → lifecycle.ErrAuthRequired → + // StateFailed immediately, and the next interactive turn recovers cleanly. + if lifecycle.IsBackgroundReconnect(ctx) { + ctx = WithoutInteractivePrompts(ctx) + } + // The MCP toolset connection needs to persist beyond the initial HTTP // request that triggered its creation. When OAuth succeeds, subsequent // agent requests should reuse the already-authenticated MCP connection. @@ -558,6 +568,19 @@ func (c *clientConnector) Connect(ctx context.Context) (lifecycle.Session, error ) return nil, errServerUnavailable } + // Map auth-related transport failures to ErrAuthRequired so the + // supervisor treats them as permanent (no retry storm) and the + // toolset lands in StateFailed for clean interactive recovery on + // the next turn. Two signals to check: + // + // 1. AuthorizationRequiredError / OAuthDeclinedError from the + // transport: enrichConnectError re-emits these through the SDK's + // %v wrapping so errors.As still works. + // 2. lifecycle.Classify detected "invalid_token" in the error + // message, which also resolves to ErrAuthRequired. + if IsAuthorizationRequired(err) || IsOAuthDeclined(err) || errors.Is(classified, lifecycle.ErrAuthRequired) { + return nil, fmt.Errorf("%w: %w", lifecycle.ErrAuthRequired, err) + } slog.ErrorContext(ctx, "Failed to initialize MCP client", "error", err) return nil, fmt.Errorf("failed to initialize MCP client: %w", err) } diff --git a/pkg/tools/mcp/oauth.go b/pkg/tools/mcp/oauth.go index b363e8631..513994ac6 100644 --- a/pkg/tools/mcp/oauth.go +++ b/pkg/tools/mcp/oauth.go @@ -32,6 +32,19 @@ import ( // resourceMetadataFromWWWAuth extracts resource metadata URL from WWW-Authenticate header var re = regexp.MustCompile(`resource="([^"]+)"`) +// errorCodeRe extracts the RFC 6750 error= parameter from a WWW-Authenticate header. 
+var errorCodeRe = regexp.MustCompile(`error="([^"]+)"`) + +// errorCodeFromWWWAuth returns the RFC 6750 error code from a WWW-Authenticate +// header value (e.g. "invalid_token"), or an empty string when absent. +func errorCodeFromWWWAuth(wwwAuth string) string { + matches := errorCodeRe.FindStringSubmatch(wwwAuth) + if len(matches) == 2 { + return matches[1] + } + return "" +} + // unmanagedOAuthWaitTimeout is the upper bound on how long the unmanaged // OAuth flow blocks waiting for a reply (elicitation result or // out-of-band callback). Generous enough to accommodate a user clicking @@ -329,24 +342,78 @@ func (t *oauthTransport) oauthClient() *http.Client { return oauthHTTPClient } -func (t *oauthTransport) authorizeOnce(ctx context.Context, authServer, wwwAuth string) error { +// handleServerRejectedToken is called when the server returned 401 for a +// request that carried a Bearer token we believed valid. It attempts to +// silently recover by: +// +// 1. Taking oauthFlowMu to serialise concurrent initialize-stage RPCs that +// all hit the same 401 at roughly the same time. +// 2. Re-checking: if another goroutine already refreshed the token (its +// AccessToken differs from prev), return nil so roundTrip replays with +// the new token. +// 3. Evicting the stale token from the store. +// 4. Attempting a refresh-token grant (if prev.RefreshToken != ""). +// 5. On success: store the new token and return nil. +// 6. On failure / no refresh token: fall back to interactive OAuth when +// the context allows it, else return AuthorizationRequiredError. +// +// The isRetry flag in roundTrip prevents a second call to this handler +// within one request so there is no infinite recursion. 
+func (t *oauthTransport) handleServerRejectedToken(ctx context.Context, prev *OAuthToken, wwwAuth string) error { t.oauthFlowMu.Lock() defer t.oauthFlowMu.Unlock() - if token := t.getValidToken(ctx); token != nil { + // Coalesce: if another goroutine already refreshed successfully, the + // stored token is now different. Return nil so the caller replays. + if current, err := t.tokenStore.GetToken(t.baseURL); err == nil && current.AccessToken != prev.AccessToken { + slog.DebugContext(ctx, "Token already refreshed by concurrent request; reusing", "url", t.baseURL) return nil } - // Sticky decline: the MCP SDK's Connect() runs several - // initialize-stage RPCs concurrently. Each one that gets a 401 - // queues here on oauthFlowMu. Without this short-circuit, the - // goroutine that wins the mutex runs the full OAuth flow; when - // the user clicks Cancel, this goroutine returns OAuthDeclinedError - // and releases the mutex — at which point the NEXT queued - // goroutine sees no valid token and fires a fresh OAuth flow, - // re-popping the dialog the user just dismissed. Latching the - // decline state under the same mutex makes the queued callers - // see the prior decline before they can start a new flow. + // Evict the stale token; the refresh or interactive flow will store a + // fresh one. + if err := t.tokenStore.RemoveToken(t.baseURL); err != nil { + slog.DebugContext(ctx, "Failed to evict stale token", "url", t.baseURL, "error", err) + } + + // Attempt a silent refresh when we have a refresh token. 
+ if prev.RefreshToken != "" { + _, err := t.refreshStoredToken(ctx, prev) + if err == nil { + slog.DebugContext(ctx, "Silently refreshed server-rejected token", "url", t.baseURL) + t.client.oauthSuccess() + return nil + } + slog.DebugContext(ctx, "Refresh failed after server-side token rejection; falling back to interactive auth", + "url", t.baseURL, "error", err) + } + + // Refresh not possible or failed: fall back to interactive OAuth if the + // context allows it. + if !interactivePromptsAllowed(ctx) { + slog.DebugContext(ctx, "Non-interactive context: deferring re-auth after server-side token rejection", "url", t.baseURL) + t.mu.Lock() + t.lastAuthRequired = true + t.mu.Unlock() + return &AuthorizationRequiredError{URL: t.baseURL} + } + + // Route through startInteractiveFlowLocked so the sticky-decline latch is + // honored: a prior user cancel short-circuits here, and a new cancel is + // latched so concurrent callers queued on oauthFlowMu observe it too. + return t.startInteractiveFlowLocked(ctx, t.baseURL, wwwAuth) +} + +// startInteractiveFlowLocked runs the interactive OAuth flow while oauthFlowMu +// is already held. It enforces the sticky-decline guard: a prior user cancel +// short-circuits immediately and returns OAuthDeclinedError, and a new cancel +// is latched so subsequent callers queued on oauthFlowMu observe it too. +// +// This is the single call-site for launching an interactive flow so that both +// authorizeOnce (first-contact 401) and handleServerRejectedToken (recovery +// after failed refresh) share the same decline-guard logic without risk of +// double-locking oauthFlowMu. 
+func (t *oauthTransport) startInteractiveFlowLocked(ctx context.Context, authServer, wwwAuth string) error { t.mu.Lock() declined := t.lastOAuthDeclined t.mu.Unlock() @@ -357,13 +424,11 @@ func (t *oauthTransport) authorizeOnce(ctx context.Context, authServer, wwwAuth err := t.handleOAuthFlow(ctx, authServer, wwwAuth) if err != nil { - // Latch the decline state BEFORE the deferred Unlock fires so - // any goroutine queued on oauthFlowMu observes it on its next - // iteration of the getValidToken / declined / handleOAuthFlow - // dance. Setting this in roundTrip (after we return) would - // race: the queued goroutine would acquire the mutex first - // and start a fresh flow while we are still bubbling the - // error up the stack. + // Latch the decline state BEFORE the deferred Unlock on oauthFlowMu + // fires so any goroutine queued on oauthFlowMu observes it on its + // next iteration. Setting this after returning would race: the queued + // goroutine could acquire the mutex first and start a fresh flow while + // we are still bubbling the error up the stack. var declinedErr *OAuthDeclinedError if errors.As(err, &declinedErr) { t.mu.Lock() @@ -374,6 +439,22 @@ func (t *oauthTransport) authorizeOnce(ctx context.Context, authServer, wwwAuth return err } +func (t *oauthTransport) authorizeOnce(ctx context.Context, authServer, wwwAuth string) error { + t.oauthFlowMu.Lock() + defer t.oauthFlowMu.Unlock() + + if token := t.getValidToken(ctx); token != nil { + return nil + } + + // Sticky decline: the MCP SDK's Connect() runs several initialize-stage + // RPCs concurrently. Each one that gets a 401 queues here on oauthFlowMu. + // startInteractiveFlowLocked checks the latch so concurrent callers that + // arrive after a user-cancel observe the prior decline and short-circuit + // without re-popping the dialog. 
+ return t.startInteractiveFlowLocked(ctx, authServer, wwwAuth) +} + func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Response, error) { var bodyBytes []byte if req.Body != nil && req.Body != http.NoBody { @@ -388,8 +469,10 @@ func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Respo reqClone := req.Clone(req.Context()) // Attach a valid token if available, silently refreshing if expired. + var attachedToken *OAuthToken if token := t.getValidToken(req.Context()); token != nil { reqClone.Header.Set("Authorization", "Bearer "+token.AccessToken) + attachedToken = token } resp, err := t.base.RoundTrip(reqClone) @@ -400,6 +483,42 @@ func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Respo if resp.StatusCode == http.StatusUnauthorized && !isRetry { wwwAuth := resp.Header.Get("WWW-Authenticate") if wwwAuth != "" { + // If a Bearer token was attached and the server is signalling a + // credential rejection (RFC 6750 invalid_token or any 401 against + // a token we believed valid), attempt silent eviction + refresh + // before falling back to interactive OAuth. This handles the common + // "token was rotated/revoked server-side" case without user + // interaction. + if attachedToken != nil { + errorCode := errorCodeFromWWWAuth(wwwAuth) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(body)) + serverMsg := extractServerMessage(body) + // Signal is: RFC 6750 error="invalid_token" in the header, or + // "invalid_token" / "invalid token" in the response body. 
+ isInvalidToken := strings.Contains(strings.ToLower(errorCode), "invalid_token") || + strings.Contains(strings.ToLower(serverMsg), "invalid_token") || + strings.Contains(strings.ToLower(serverMsg), "invalid token") + if isInvalidToken { + if len(bodyBytes) > 0 { + req.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) + } + if err := t.handleServerRejectedToken(req.Context(), attachedToken, wwwAuth); err != nil { + // Refresh or re-auth deferred; caller will surface the error. + return nil, err + } + // Token refreshed successfully; replay the request. + return t.roundTrip(req, true) + } + // Token was attached but the server returned a bare 401 without + // invalid_token. This is an app-level authorization failure + // (wrong permissions, revoked app access, etc.) the transport + // cannot silently recover from. Return the 401 as-is: no + // eviction, no /token call, stored credential untouched. + return resp, nil + } + // If the caller asked for non-interactive operation (e.g. the // runtime is populating sidebar tool counts during startup), // don't block on an OAuth elicitation that the TUI is not yet @@ -657,13 +776,28 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { return nil } + newToken, err := t.refreshStoredToken(ctx, token) + if err != nil { + return nil + } + return newToken +} + +// refreshStoredToken attempts a silent refresh-token grant for prev. It +// honours the 30-second refreshFailedAt backoff to avoid hammering the +// token endpoint on repeated failures, and resets it on success. +// +// On success the new token is stored and returned; on failure nil and the +// error are returned and refreshFailedAt is stamped. The caller is +// responsible for ensuring prev.RefreshToken is non-empty before calling. +func (t *oauthTransport) refreshStoredToken(ctx context.Context, prev *OAuthToken) (*OAuthToken, error) { // Avoid hammering the token endpoint if a recent refresh already failed. 
const refreshBackoff = 30 * time.Second t.mu.Lock() failedAt := t.refreshFailedAt t.mu.Unlock() if !failedAt.IsZero() && time.Since(failedAt) < refreshBackoff { - return nil + return nil, fmt.Errorf("skipping refresh: last attempt failed %s ago", time.Since(failedAt).Round(time.Second)) } slog.DebugContext(ctx, "Attempting silent token refresh", "url", t.baseURL) @@ -688,23 +822,23 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { defer refreshSpan.End() o := &oauth{metadataClient: t.oauthClient()} - authServer := cmp.Or(token.AuthServer, t.baseURL) + authServer := cmp.Or(prev.AuthServer, t.baseURL) metadata, err := o.getAuthorizationServerMetadata(ctx, authServer) if err != nil { slog.DebugContext(ctx, "Failed to fetch auth server metadata for refresh", "auth_server", authServer, "error", err) refreshSpan.RecordError(err) refreshSpan.SetStatus(codes.Error, "metadata fetch failed") refreshSpan.SetAttributes(attribute.String("error.type", "metadata")) - return nil + return nil, err } newToken, err := refreshAccessToken( ctx, t.oauthClient(), metadata.TokenEndpoint, - token.RefreshToken, - token.ClientID, - token.ClientSecret, + prev.RefreshToken, + prev.ClientID, + prev.ClientSecret, ) if err != nil { slog.DebugContext(ctx, "Token refresh failed, will require interactive auth", "error", err) @@ -714,10 +848,10 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { t.mu.Lock() t.refreshFailedAt = time.Now() t.mu.Unlock() - return nil + return nil, err } newToken.AuthServer = authServer - newToken.RequestedScopes = token.RequestedScopes + newToken.RequestedScopes = prev.RequestedScopes t.mu.Lock() t.refreshFailedAt = time.Time{} // reset on success @@ -728,7 +862,7 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { } slog.DebugContext(ctx, "Token refreshed successfully", "url", t.baseURL) - return newToken + return newToken, nil } // tokenCoversConfiguredScopes reports whether the stored token 
was obtained diff --git a/pkg/tools/mcp/oauth_test.go b/pkg/tools/mcp/oauth_test.go index 3b3b1f196..189e2bca6 100644 --- a/pkg/tools/mcp/oauth_test.go +++ b/pkg/tools/mcp/oauth_test.go @@ -2054,3 +2054,483 @@ func TestUnmanagedRedirectURI_PerToolsetTakesPrecedence(t *testing.T) { transport.unmanagedOAuthRedirectURI = "" assert.Empty(t, transport.unmanagedRedirectURI()) } + +// --------- Server-side invalid_token eviction + refresh tests --------- + +// newInvalidTokenTestServer creates an httptest mux emulating a server that: +// - Returns 401 + WWW-Authenticate: Bearer error="invalid_token" for the +// stale bearer token "old-at". +// - Returns 200 for any request bearing a valid bearer token "fresh-at". +// - Serves OAuth authorization-server metadata and a /token refresh endpoint. +func newInvalidTokenTestServer(t *testing.T) (*httptest.Server, *atomic.Int32) { + t.Helper() + var tokenCalls atomic.Int32 + mux := http.NewServeMux() + + // Use NewUnstartedServer so we can reference srv.URL in handler closures + // before the server is started. 
+ srv := httptest.NewUnstartedServer(mux) + srv.Start() + t.Cleanup(srv.Close) + + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(AuthorizationServerMetadata{ + Issuer: srv.URL, + AuthorizationEndpoint: srv.URL + "/authorize", + TokenEndpoint: srv.URL + "/token", + }) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + tokenCalls.Add(1) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "fresh-at", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "fresh-rt", + }) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") == "Bearer fresh-at" { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + return + } + // All other tokens (including the stale "old-at") → invalid_token. + w.Header().Set("WWW-Authenticate", `Bearer realm="test", error="invalid_token"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token","error_description":"Invalid access token"}`)) + }) + + return srv, &tokenCalls +} + +// newTransportWithStaleToken builds an oauthTransport pre-populated with a +// non-expired stale token "old-at" that the server will reject with +// invalid_token. The token carries a refresh_token and client credentials so +// the silent refresh path can be exercised. 
+func newTransportWithStaleToken(t *testing.T, baseURL string) *oauthTransport { + t.Helper() + store := NewInMemoryTokenStore() + err := store.StoreToken(baseURL, &OAuthToken{ + AccessToken: "old-at", + TokenType: "Bearer", + RefreshToken: "old-rt", + // Non-expired so getValidToken returns it without local-expiry refresh. + ExpiresAt: time.Now().Add(1 * time.Hour), + ClientID: "cid", + ClientSecret: "csec", + AuthServer: baseURL, // metadata discovery goes to the same server + }) + require.NoError(t, err) + return &oauthTransport{ + base: http.DefaultTransport, + client: &remoteMCPClient{}, + tokenStore: store, + baseURL: baseURL, + // Allow private IPs so the httptest 127.0.0.1 server is reachable. + oauthHTTPClient: oauthHTTPClientForAllowPrivateIPs(true), + } +} + +// TestRoundTrip_ServerInvalidTokenEvictsAndRefreshes is the primary regression +// test for #3198: when the server rejects a non-expired stored token with +// invalid_token, the transport must silently evict the stale token, call the +// token endpoint exactly once to refresh it, and replay the request with the +// new bearer so the caller sees a 200 response. +func TestRoundTrip_ServerInvalidTokenEvictsAndRefreshes(t *testing.T) { + srv, tokenCalls := newInvalidTokenTestServer(t) + + var oauthSuccessFired atomic.Bool + transport := newTransportWithStaleToken(t, srv.URL) + transport.client.SetOAuthSuccessHandler(func() { oauthSuccessFired.Store(true) }) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), tokenCalls.Load(), "token endpoint must be called exactly once") + + // The stored token must now be the freshly-obtained one. 
+ stored, err := transport.tokenStore.GetToken(srv.URL) + require.NoError(t, err) + assert.Equal(t, "fresh-at", stored.AccessToken) + + assert.True(t, oauthSuccessFired.Load(), "oauthSuccess must be called after silent refresh") +} + +// TestRoundTrip_ServerInvalidToken_RefreshFails_DefersWhenNonInteractive +// verifies that when the server rejects the token with invalid_token but the +// token-endpoint refresh fails, a non-interactive context returns +// IsAuthorizationRequired and the stale token is evicted from the store. +func TestRoundTrip_ServerInvalidToken_RefreshFails_DefersWhenNonInteractive(t *testing.T) { + var tokenCalls atomic.Int32 + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(AuthorizationServerMetadata{ + Issuer: srv.URL, + AuthorizationEndpoint: srv.URL + "/authorize", + TokenEndpoint: srv.URL + "/token", + }) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) { + tokenCalls.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer realm="test", error="invalid_token"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + }) + + transport := newTransportWithStaleToken(t, srv.URL) + + ctx := WithoutInteractivePrompts(t.Context()) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, srv.URL, strings.NewReader("{}")) + require.NoError(t, err) + + resp, rtErr := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + require.Error(t, rtErr) + assert.True(t, IsAuthorizationRequired(rtErr), "must return 
AuthorizationRequiredError; got: %v", rtErr) + + // Stale token must have been evicted. + _, storeErr := transport.tokenStore.GetToken(srv.URL) + require.Error(t, storeErr, "stale token must be evicted from the store") + + // The token endpoint was hit exactly once (backoff not yet active). + assert.Equal(t, int32(1), tokenCalls.Load()) + + // A second request within the backoff window must NOT hit the token + // endpoint again (refreshFailedAt is set). + req2, _ := http.NewRequestWithContext(ctx, http.MethodPost, srv.URL, strings.NewReader("{}")) + // Re-populate the store with a stale token so roundTrip tries to attach it again. + _ = transport.tokenStore.StoreToken(srv.URL, &OAuthToken{ + AccessToken: "old-at", + RefreshToken: "old-rt", + ExpiresAt: time.Now().Add(time.Hour), + ClientID: "cid", + ClientSecret: "csec", + AuthServer: srv.URL, + }) + resp2, _ := transport.RoundTrip(req2) + if resp2 != nil { + resp2.Body.Close() + } + assert.Equal(t, int32(1), tokenCalls.Load(), "backoff must prevent a second token-endpoint hit") +} + +// TestRoundTrip_ServerInvalidToken_NoRefreshToken_NonInteractive verifies +// that when the stored token has no refresh_token, a non-interactive context +// returns IsAuthorizationRequired without hitting the token endpoint. 
+func TestRoundTrip_ServerInvalidToken_NoRefreshToken_NonInteractive(t *testing.T) { + mux := http.NewServeMux() + var tokenCalls atomic.Int32 + + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer realm="test", error="invalid_token"`) + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) { + tokenCalls.Add(1) + w.WriteHeader(http.StatusBadRequest) + }) + + store := NewInMemoryTokenStore() + err := store.StoreToken(srv.URL, &OAuthToken{ + AccessToken: "old-at", + TokenType: "Bearer", + // No RefreshToken → refresh path must not be attempted. + ExpiresAt: time.Now().Add(time.Hour), + AuthServer: srv.URL, + }) + require.NoError(t, err) + + transport := &oauthTransport{ + base: http.DefaultTransport, + client: &remoteMCPClient{}, + tokenStore: store, + baseURL: srv.URL, + oauthHTTPClient: oauthHTTPClientForAllowPrivateIPs(true), + } + + ctx := WithoutInteractivePrompts(t.Context()) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, srv.URL, strings.NewReader("{}")) + resp, rtErr := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + require.Error(t, rtErr) + assert.True(t, IsAuthorizationRequired(rtErr)) + assert.Equal(t, int32(0), tokenCalls.Load(), "no refresh token → token endpoint must never be hit") +} + +// TestRoundTrip_FirstContact401Unchanged is a regression guard: when no +// token is stored (first-contact), the transport must take the existing +// interactive OAuth flow unchanged — it must NOT treat the first-contact +// 401 as an invalid_token rejection. 
+func TestRoundTrip_FirstContact401Unchanged(t *testing.T) { + srv := newUnmanagedOAuthTestServer(t) + + var elicitCalls atomic.Int32 + capture := &elicitCaptured{} + capture.replyFn = func(_ *gomcp.ElicitParams) tools.ElicitationResult { + elicitCalls.Add(1) + return tools.ElicitationResult{ + Action: tools.ElicitationActionAccept, + Content: map[string]any{ + "access_token": "first-at", + "token_type": "Bearer", + }, + } + } + transport, _ := newUnmanagedTestTransport(t, srv.URL, "", capture) + // No pre-stored token → first-contact path. + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + require.NoError(t, err) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), elicitCalls.Load(), "first-contact must go through the interactive OAuth flow") + // Token endpoint was NOT hit for first-contact (client supplied the token directly). + assert.Equal(t, int32(0), srv.tokenCalls.Load()) +} + +// TestRoundTrip_ConcurrentInvalidToken_RefreshesOnce verifies the +// coalescing behaviour: when N concurrent roundTrips all see the stale +// token and hit 401 + invalid_token simultaneously, exactly one refresh +// is issued and all N requests eventually succeed with the fresh token. 
+func TestRoundTrip_ConcurrentInvalidToken_RefreshesOnce(t *testing.T) { + srv, tokenCalls := newInvalidTokenTestServer(t) + + transport := newTransportWithStaleToken(t, srv.URL) + transport.client.SetOAuthSuccessHandler(func() {}) + + const n = 6 + results := make(chan error, n) + + for range n { + go func() { + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + if err != nil { + results <- err + return + } + resp, err := transport.RoundTrip(req) + if resp != nil { + _ = resp.Body.Close() + } + results <- err + }() + } + + for range n { + select { + case err := <-results: + require.NoError(t, err, "all concurrent requests must eventually succeed") + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for concurrent round-trips") + } + } + + assert.Equal(t, int32(1), tokenCalls.Load(), + "oauthFlowMu coalescing must ensure the token endpoint is hit exactly once") +} + +// TestRoundTrip_Bare401WithAttachedTokenUnchanged is the regression guard for +// blocking issue 1: when a token WAS attached to the request and the server +// returns a plain 401 (no error="invalid_token" in the WWW-Authenticate +// header and no invalid_token in the body), the transport MUST return the 401 +// response as-is — no eviction, no /token call, stored credential untouched. +// +// This covers the "app-level authorization failure" case (wrong permissions, +// revoked app access, etc.) where the server rejects the bearer for a reason +// that isn't a stale/revoked token we can silently refresh. +func TestRoundTrip_Bare401WithAttachedTokenUnchanged(t *testing.T) { + var tokenCalls atomic.Int32 + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + // Main endpoint: 401 with Bearer realm but NO error= code. 
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer realm="test"`) + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message":"forbidden"}`)) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) { + tokenCalls.Add(1) + w.WriteHeader(http.StatusBadRequest) + }) + + store := NewInMemoryTokenStore() + require.NoError(t, store.StoreToken(srv.URL, &OAuthToken{ + AccessToken: "good-at", + TokenType: "Bearer", + RefreshToken: "good-rt", // has a refresh token — must NOT be used + ExpiresAt: time.Now().Add(time.Hour), + AuthServer: srv.URL, + ClientID: "cid", + ClientSecret: "csec", + })) + + transport := &oauthTransport{ + base: http.DefaultTransport, + client: &remoteMCPClient{}, + tokenStore: store, + baseURL: srv.URL, + oauthHTTPClient: oauthHTTPClientForAllowPrivateIPs(true), + } + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + require.NoError(t, err) + + resp, rtErr := transport.RoundTrip(req) + require.NoError(t, rtErr, "bare 401 must be returned as a normal response, not an error") + if resp != nil { + resp.Body.Close() + } + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "bare 401 without invalid_token must be returned unchanged") + assert.Equal(t, int32(0), tokenCalls.Load(), + "token endpoint must NOT be called for a bare 401 without invalid_token") + + // Stored token must be untouched. 
+ stored, storeErr := transport.tokenStore.GetToken(srv.URL) + require.NoError(t, storeErr, "token must still be in the store") + assert.Equal(t, "good-at", stored.AccessToken, + "stored token must not be evicted or replaced on a bare 401") +} + +// TestRoundTrip_ConcurrentInvalidToken_NoRefresh_StickyDecline is the +// regression guard for blocking issue 2: when concurrent requests all hit a +// 401 + error="invalid_token" for a token that has NO refresh_token, the +// transport must fall back to interactive OAuth. The first goroutine runs the +// OAuth flow; the user declines. The sticky-decline latch (lastOAuthDeclined) +// must be observed by all subsequently-queued goroutines so that exactly ONE +// elicitation dialog is popped and all callers surface OAuthDeclinedError. +func TestRoundTrip_ConcurrentInvalidToken_NoRefresh_StickyDecline(t *testing.T) { + var tokenCalls atomic.Int32 + mux := http.NewServeMux() + srv := httptest.NewUnstartedServer(mux) + srv.Start() + t.Cleanup(srv.Close) + + // OAuth discovery: needed by handleUnmanagedOAuthFlow before it sends the + // elicitation. Both endpoints must succeed for the flow to reach elicitation. 
+ mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "resource": srv.URL, + "authorization_servers": []string{srv.URL}, + }) + }) + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(AuthorizationServerMetadata{ + Issuer: srv.URL, + AuthorizationEndpoint: srv.URL + "/authorize", + TokenEndpoint: srv.URL + "/token", + }) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) { + tokenCalls.Add(1) + w.WriteHeader(http.StatusBadRequest) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer realm="test", error="invalid_token"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + }) + + // Token without a refresh_token: no silent refresh, must fall back to + // interactive OAuth. + store := NewInMemoryTokenStore() + require.NoError(t, store.StoreToken(srv.URL, &OAuthToken{ + AccessToken: "old-at", + TokenType: "Bearer", + // No RefreshToken. 
+ ExpiresAt: time.Now().Add(time.Hour), + AuthServer: srv.URL, + })) + + var elicitCount atomic.Int32 + capture := &elicitCaptured{} + capture.replyFn = func(_ *gomcp.ElicitParams) tools.ElicitationResult { + elicitCount.Add(1) + return tools.ElicitationResult{Action: tools.ElicitationActionDecline} + } + + client := newRemoteClient(srv.URL, "streamable", nil, store, nil, false) + client.SetElicitationHandler(capture.handler) + client.allowPrivateIPs = true + + transport := &oauthTransport{ + base: http.DefaultTransport, + client: client, + tokenStore: store, + baseURL: srv.URL, + managed: false, + oauthHTTPClient: oauthHTTPClientForAllowPrivateIPs(true), + } + + const n = 5 + results := make(chan error, n) + for range n { + go func() { + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + if err != nil { + results <- err + return + } + resp, err := transport.RoundTrip(req) + if resp != nil { + _ = resp.Body.Close() + } + results <- err + }() + } + + for range n { + select { + case err := <-results: + require.Error(t, err) + assert.True(t, IsOAuthDeclined(err), + "all concurrent requests must surface OAuthDeclinedError after the user declined; got: %v", err) + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for concurrent round-trips to complete") + } + } + + assert.Equal(t, int32(1), elicitCount.Load(), + "sticky-decline must ensure exactly one elicitation dialog is shown") + assert.Equal(t, int32(0), tokenCalls.Load(), + "token endpoint must NOT be hit: no refresh token exists") +} diff --git a/pkg/tools/startable.go b/pkg/tools/startable.go index b23b5a6bb..a819b63ba 100644 --- a/pkg/tools/startable.go +++ b/pkg/tools/startable.go @@ -89,6 +89,12 @@ type StartableToolSet struct { started bool startStreak failureStreak // Start() failures listStreak failureStreak // Tools() listing failures + // recoveryStreak tracks once-per-streak notices specifically for + // recovery failures (the toolset 
was previously started and working, + // then Start failed again). Distinct from startStreak so callers can + // emit a different, more targeted message (e.g. "needs re-auth" vs + // "start failed") for the recovery case. + recoveryStreak failureStreak } // NewStartable wraps a ToolSet for lazy initialization. @@ -152,6 +158,7 @@ func (s *StartableToolSet) Start(ctx context.Context) (err error) { }() if err := restarter.Restart(ctx); err != nil { s.startStreak.fail() + s.recoveryStreak.fail() return err } } else if startable, ok := As[Startable](s.ToolSet); ok { @@ -178,6 +185,7 @@ func (s *StartableToolSet) Start(ctx context.Context) (err error) { // as fresh. This is the recovery path — it is intentionally silent. s.started = true s.startStreak.reset() + s.recoveryStreak.reset() return nil } @@ -208,6 +216,7 @@ func (s *StartableToolSet) Stop(ctx context.Context) error { s.started = false s.startStreak.reset() s.listStreak.reset() + s.recoveryStreak.reset() if startable, ok := As[Startable](s.ToolSet); ok { return startable.Stop(ctx) } @@ -234,6 +243,22 @@ func (s *StartableToolSet) ShouldReportListFailure() bool { return s.listStreak.shouldReport() } +// ShouldReportRecoveryFailure returns true exactly once per recovery-failure +// streak — when a toolset that was previously started and working fails to +// restart (e.g. because the server revoked the OAuth token in the background). +// +// Unlike ShouldReportFailure (which fires for both initial and recovery +// failures), this method fires only for recovery failures so callers can +// emit a targeted "needs re-authentication" notice instead of a generic +// "start failed" one. Returns false for initial-startup auth deferral +// (those are silent pending prompts and the dialog appears naturally on +// the first interactive turn). 
+func (s *StartableToolSet) ShouldReportRecoveryFailure() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.recoveryStreak.shouldReport() +} + // Unwrap returns the underlying ToolSet. func (s *StartableToolSet) Unwrap() ToolSet { return s.ToolSet diff --git a/pkg/tools/startable_test.go b/pkg/tools/startable_test.go index 301562a3b..de4dfb3ce 100644 --- a/pkg/tools/startable_test.go +++ b/pkg/tools/startable_test.go @@ -321,3 +321,139 @@ func TestStartableToolSet_NoStartReporterPreservesLatchedStart(t *testing.T) { assert.NilError(t, s.Start(t.Context())) assert.Check(t, is.Equal(inner.startups, 1)) } + +// recoveryFailingToolSet simulates a toolset that starts successfully on +// the first attempt (Start) and then fails on every Restart call, +// representing a toolset that was working but became unavailable. +type recoveryFailingToolSet struct { + started bool + restartErr error +} + +func (r *recoveryFailingToolSet) Tools(context.Context) ([]tools.Tool, error) { return nil, nil } +func (r *recoveryFailingToolSet) IsStarted() bool { return r.started } +func (r *recoveryFailingToolSet) Start(context.Context) error { + r.started = true + return nil +} +func (r *recoveryFailingToolSet) Restart(_ context.Context) error { return r.restartErr } +func (r *recoveryFailingToolSet) Stop(_ context.Context) error { + r.started = false + return nil +} + +// TestStartableToolSet_ShouldReportRecoveryFailure_OncePerStreak verifies +// that ShouldReportRecoveryFailure returns true exactly once when a +// previously-started toolset fails to recover (recovering=true path), and +// is silent for subsequent calls in the same streak. +func TestStartableToolSet_ShouldReportRecoveryFailure_OncePerStreak(t *testing.T) { + t.Parallel() + + authErr := errors.New("authentication required") + inner := &recoveryFailingToolSet{restartErr: authErr} + s := tools.NewStartable(inner) + + // First Start: succeeds and marks the toolset as started. 
+ assert.NilError(t, s.Start(t.Context())) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), "no recovery failure yet") + + // Simulate the inner toolset going down (e.g. background reconnect failed). + inner.started = false + + // Recovery attempt 1: Restart fails → streak begins. + assert.Check(t, s.Start(t.Context()) != nil, "expected error on recovery") + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), true), "first recovery failure must be reported") + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), "second call in same streak must be false") +} + +// TestStartableToolSet_ShouldReportRecoveryFailure_NotFiredForInitialStartup +// verifies that ShouldReportRecoveryFailure is NOT triggered for initial- +// startup failures (toolset was never started before). Only recovery +// failures (toolset was working, then failed) should trigger the notice. +func TestStartableToolSet_ShouldReportRecoveryFailure_NotFiredForInitialStartup(t *testing.T) { + t.Parallel() + + errBoom := errors.New("startup error") + f := &flappyToolSet{errs: []error{errBoom, errBoom}} + s := tools.NewStartable(f) + + // Turn 1: initial startup failure (never started before). + assert.Check(t, s.Start(t.Context()) != nil) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), + "initial-startup failure must NOT trigger recovery notice") + + // Turn 2: second startup failure. + assert.Check(t, s.Start(t.Context()) != nil) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), + "repeated initial-startup failure must NOT trigger recovery notice") +} + +// TestStartableToolSet_ShouldReportRecoveryFailure_ResetsOnSuccess verifies +// that a successful recovery clears the streak so a future failure is +// reported as fresh. 
+func TestStartableToolSet_ShouldReportRecoveryFailure_ResetsOnSuccess(t *testing.T) { + t.Parallel() + + authErr := errors.New("authentication required") + inner := &recoveryFailingToolSet{restartErr: authErr} + s := tools.NewStartable(inner) + + // Initial start succeeds (Start always returns nil for recoveryFailingToolSet). + assert.NilError(t, s.Start(t.Context())) + assert.Check(t, is.Equal(s.IsStarted(), true)) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), "no recovery failure yet") + + // Background failure: inner loses its connection. + inner.started = false + + // Recovery fails: Restart returns authErr. + err := s.Start(t.Context()) + assert.Check(t, err != nil, "expected error on recovery failure") + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), true), "first recovery failure must be reported") + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), "second call in same streak must return false (dedup)") + + // Successful recovery: clear the error so the next Start goes through and + // resets the streak. Because s.started==false after the failed Restart, + // Start takes the non-recovery path (inner.Start), which succeeds. + inner.restartErr = nil + assert.NilError(t, s.Start(t.Context()), "recovery with nil restartErr must succeed") + assert.Check(t, is.Equal(s.IsStarted(), true)) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), false), + "after successful recovery, the streak must be reset") + + // A subsequent background failure after the reset is a fresh streak. + inner.restartErr = authErr + inner.started = false + _ = s.Start(t.Context()) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), true), + "fresh failure after streak reset must be reported") +} + +// TestStartableToolSet_ShouldReportRecoveryFailure_ResetsOnStop verifies +// that Stop clears the recovery streak. 
+func TestStartableToolSet_ShouldReportRecoveryFailure_ResetsOnStop(t *testing.T) { + t.Parallel() + + authErr := errors.New("authentication required") + inner := &recoveryFailingToolSet{restartErr: authErr} + s := tools.NewStartable(inner) + + // Initial start → recovery failure → consume the once-report. + assert.NilError(t, s.Start(t.Context())) + inner.started = false + assert.Check(t, s.Start(t.Context()) != nil) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), true), "must report once") + + // Stop resets all streaks. + assert.NilError(t, s.Stop(t.Context())) + + // A new recovery cycle after Stop must report again. + inner.started = false // already false after the inner Stop; set explicitly for clarity + inner.restartErr = nil + assert.NilError(t, s.Start(t.Context())) // inner Start succeeds (restartErr cleared) + inner.started = false + inner.restartErr = authErr + + assert.Check(t, s.Start(t.Context()) != nil) + assert.Check(t, is.Equal(s.ShouldReportRecoveryFailure(), true), "fresh recovery after Stop must report again") +}
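The once-per-streak behaviour exercised by the tests above can be sketched in isolation. The diff does not show how `failureStreak` is implemented, so the version below is a hypothetical stand-in that only reuses the method names (`fail`, `reset`, `shouldReport`); `recoveryTracker` and its methods are likewise illustrative names, not part of `pkg/tools`.

```go
package main

import (
	"fmt"
	"sync"
)

// failureStreak is a hypothetical stand-in for the unexported type of the
// same name in pkg/tools; the real implementation is not shown in this diff.
type failureStreak struct {
	active   bool // at least one failure since the last reset
	reported bool // the current streak has already been surfaced once
}

func (f *failureStreak) fail()  { f.active = true }
func (f *failureStreak) reset() { *f = failureStreak{} }

// shouldReport returns true exactly once per streak of failures.
func (f *failureStreak) shouldReport() bool {
	if !f.active || f.reported {
		return false
	}
	f.reported = true
	return true
}

// recoveryTracker (illustrative name) guards the streak with a mutex, the way
// ShouldReportRecoveryFailure locks s.mu before consulting recoveryStreak.
type recoveryTracker struct {
	mu     sync.Mutex
	streak failureStreak
}

// RestartFailed records a failed recovery attempt.
func (r *recoveryTracker) RestartFailed() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.streak.fail()
}

// Recovered clears the streak after a successful start or an explicit stop.
func (r *recoveryTracker) Recovered() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.streak.reset()
}

// ShouldReportRecoveryFailure reports whether a notice should be emitted now.
func (r *recoveryTracker) ShouldReportRecoveryFailure() bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.streak.shouldReport()
}

func main() {
	var r recoveryTracker
	r.RestartFailed()
	fmt.Println(r.ShouldReportRecoveryFailure()) // true: first failure of the streak
	fmt.Println(r.ShouldReportRecoveryFailure()) // false: already reported
	r.RestartFailed()
	fmt.Println(r.ShouldReportRecoveryFailure()) // false: still the same streak
	r.Recovered()
	r.RestartFailed()
	fmt.Println(r.ShouldReportRecoveryFailure()) // true: fresh streak after reset
}
```

Keeping the "already reported" flag inside the streak, rather than in the caller, is what lets a caller poll on every turn without repeating the notice; whether the real type tracks it the same way is an assumption.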