diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index 3c29fe03d30..0dbba27fbd4 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -359,8 +359,13 @@ export namespace MCP { } catch (error) { lastError = error instanceof Error ? error : new Error(String(error)) - // Handle OAuth-specific errors - if (error instanceof UnauthorizedError) { + // Handle OAuth-specific errors. + // The SDK throws UnauthorizedError when auth() returns 'REDIRECT', + // but may also throw plain Errors when auth() fails internally + // (e.g. during discovery, registration, or state generation). + // When an authProvider is attached, treat both cases as auth-related. + const isAuthError = error instanceof UnauthorizedError || (authProvider && lastError.message.includes("OAuth")) + if (isAuthError) { log.info("mcp server requires authentication", { key, transport: name }) // Check if this is a "needs registration" error diff --git a/packages/opencode/src/mcp/oauth-provider.ts b/packages/opencode/src/mcp/oauth-provider.ts index 164b1d1f143..b4da73169e1 100644 --- a/packages/opencode/src/mcp/oauth-provider.ts +++ b/packages/opencode/src/mcp/oauth-provider.ts @@ -144,10 +144,19 @@ export class McpOAuthProvider implements OAuthClientProvider { async state(): Promise { const entry = await McpAuth.get(this.mcpName) - if (!entry?.oauthState) { - throw new Error(`No OAuth state saved for MCP server: ${this.mcpName}`) + if (entry?.oauthState) { + return entry.oauthState } - return entry.oauthState + + // Generate a new state if none exists — the SDK calls state() as a + // generator, not just a reader, so we need to produce a value even when + // startAuth() hasn't pre-saved one (e.g. during automatic auth on first + // connect). + const newState = Array.from(crypto.getRandomValues(new Uint8Array(32))) + .map((b) => b.toString(16).padStart(2, "0")) + .join("") + await McpAuth.updateOAuthState(this.mcpName, newState) + return newState } async invalidateCredentials(type: "all" | "client" | "tokens"): Promise { diff --git a/packages/opencode/test/mcp/oauth-auto-connect.test.ts b/packages/opencode/test/mcp/oauth-auto-connect.test.ts new file mode 100644 index 00000000000..0cd5c36e527 --- /dev/null +++ b/packages/opencode/test/mcp/oauth-auto-connect.test.ts @@ -0,0 +1,197 @@ +import { test, expect, mock, beforeEach } from "bun:test" + +// Mock UnauthorizedError to match the SDK's class +class MockUnauthorizedError extends Error { + constructor(message?: string) { + super(message ?? "Unauthorized") + this.name = "UnauthorizedError" + } +} + +// Track what options were passed to each transport constructor +const transportCalls: Array<{ + type: "streamable" | "sse" + url: string + options: { authProvider?: unknown } +}> = [] + +// Controls whether the mock transport simulates a 401 that triggers the SDK +// auth flow (which calls provider.state()) or a simple UnauthorizedError. +let simulateAuthFlow = true + +// Mock the transport constructors to simulate OAuth auto-auth on 401 +mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ + StreamableHTTPClientTransport: class MockStreamableHTTP { + authProvider: { + state?: () => Promise + redirectToAuthorization?: (url: URL) => Promise + saveCodeVerifier?: (v: string) => Promise + } | undefined + constructor(url: URL, options?: { authProvider?: unknown }) { + this.authProvider = options?.authProvider as typeof this.authProvider + transportCalls.push({ + type: "streamable", + url: url.toString(), + options: options ?? {}, + }) + } + async start() { + // Simulate what the real SDK transport does on 401: + // It calls auth() which eventually calls provider.state(), then + // provider.redirectToAuthorization(), then throws UnauthorizedError. + if (simulateAuthFlow && this.authProvider) { + // The SDK calls provider.state() to get the OAuth state parameter + if (this.authProvider.state) { + await this.authProvider.state() + } + // The SDK calls saveCodeVerifier before redirecting + if (this.authProvider.saveCodeVerifier) { + await this.authProvider.saveCodeVerifier("test-verifier") + } + // The SDK calls redirectToAuthorization to redirect the user + if (this.authProvider.redirectToAuthorization) { + await this.authProvider.redirectToAuthorization(new URL("https://auth.example.com/authorize?state=test")) + } + throw new MockUnauthorizedError() + } + throw new MockUnauthorizedError() + } + async finishAuth(_code: string) {} + }, +})) + +mock.module("@modelcontextprotocol/sdk/client/sse.js", () => ({ + SSEClientTransport: class MockSSE { + constructor(url: URL, options?: { authProvider?: unknown }) { + transportCalls.push({ + type: "sse", + url: url.toString(), + options: options ?? {}, + }) + } + async start() { + throw new Error("Mock SSE transport cannot connect") + } + }, +})) + +// Mock the MCP SDK Client +mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({ + Client: class MockClient { + async connect(transport: { start: () => Promise }) { + await transport.start() + } + }, +})) + +// Mock UnauthorizedError in the auth module so instanceof checks work +mock.module("@modelcontextprotocol/sdk/client/auth.js", () => ({ + UnauthorizedError: MockUnauthorizedError, +})) + +beforeEach(() => { + transportCalls.length = 0 + simulateAuthFlow = true +}) + +// Import modules after mocking +const { MCP } = await import("../../src/mcp/index") +const { Instance } = await import("../../src/project/instance") +const { tmpdir } = await import("../fixture/fixture") + +test("first connect to OAuth server shows needs_auth instead of failed", async () => { + await using tmp = await tmpdir({ + init: async (dir) => { + await Bun.write( + `${dir}/opencode.json`, + JSON.stringify({ + $schema: "https://opencode.ai/config.json", + mcp: { + "test-oauth": { + type: "remote", + url: "https://example.com/mcp", + }, + }, + }), + ) + }, + }) + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const result = await MCP.add("test-oauth", { + type: "remote", + url: "https://example.com/mcp", + }) + + const serverStatus = result.status as Record + + // The server should be detected as needing auth, NOT as failed. + // Before the fix, provider.state() would throw a plain Error + // ("No OAuth state saved for MCP server: test-oauth") which was + // not caught as UnauthorizedError, causing status to be "failed". + expect(serverStatus["test-oauth"]).toBeDefined() + expect(serverStatus["test-oauth"].status).toBe("needs_auth") + }, + }) +}) + +test("state() generates a new state when none is saved", async () => { + const { McpOAuthProvider } = await import("../../src/mcp/oauth-provider") + const { McpAuth } = await import("../../src/mcp/auth") + + await using tmp = await tmpdir() + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const provider = new McpOAuthProvider( + "test-state-gen", + "https://example.com/mcp", + {}, + { onRedirect: async () => {} }, + ) + + // Ensure no state exists + const entryBefore = await McpAuth.get("test-state-gen") + expect(entryBefore?.oauthState).toBeUndefined() + + // state() should generate and return a new state, not throw + const state = await provider.state() + expect(typeof state).toBe("string") + expect(state.length).toBe(64) // 32 bytes as hex + + // The generated state should be persisted + const entryAfter = await McpAuth.get("test-state-gen") + expect(entryAfter?.oauthState).toBe(state) + }, + }) +}) + +test("state() returns existing state when one is saved", async () => { + const { McpOAuthProvider } = await import("../../src/mcp/oauth-provider") + const { McpAuth } = await import("../../src/mcp/auth") + + await using tmp = await tmpdir() + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const provider = new McpOAuthProvider( + "test-state-existing", + "https://example.com/mcp", + {}, + { onRedirect: async () => {} }, + ) + + // Pre-save a state + const existingState = "pre-saved-state-value" + await McpAuth.updateOAuthState("test-state-existing", existingState) + + // state() should return the existing state + const state = await provider.state() + expect(state).toBe(existingState) + }, + }) +})