diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 9a81358b..ac51ea12 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -101,6 +101,19 @@ type AuthorizationCodeHandlerConfig struct { // It should return the authorization code and state once the Authorization Server // redirects back to the RedirectURL. AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) + + // Client is an optional HTTP client to use for HTTP requests. + // It is used for the following requests: + // - Fetching Protected Resource Metadata + // - Fetching Authorization Server Metadata + // - Registering a client dynamically + // - Exchanging an authorization code for an access token + // - Refreshing an access token + // Custom clients can include additional security configurations, + // such as SSRF protections, see + // https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#server-side-request-forgery-ssrf + // If not provided, http.DefaultClient will be used. + Client *http.Client } // AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses @@ -166,6 +179,9 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho // it should have been set by now. Otherwise, it is required. return nil, errors.New("RedirectURL is required") } + if config.Client == nil { + config.Client = http.DefaultClient + } return &AuthorizationCodeHandler{config: config}, nil } @@ -280,7 +296,7 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont // Use MCP server URL as the resource URI per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. for _, url := range protectedResourceMetadataURLs(resourceMetadataURLFromChallenges(wwwChallenges), mcpServerURL) { - prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient) + prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, h.config.Client) if err != nil { errs = append(errs, err) continue @@ -359,7 +375,7 @@ func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) { authServerURL := prm.AuthorizationServers[0] for _, u := range authorizationServerMetadataURLs(authServerURL) { - asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient) + asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, h.config.Client) if err != nil { return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) } @@ -488,7 +504,7 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * // 3. Attempt to use dynamic client registration. dcrCfg := h.config.DynamicClientRegistrationConfig if dcrCfg != nil && asm.RegistrationEndpoint != "" { - regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) + regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, h.config.Client) if err != nil { return nil, fmt.Errorf("failed to register client: %w", err) } @@ -542,10 +558,11 @@ func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context oauth2.VerifierOption(authResult.usedCodeVerifier), oauth2.SetAuthURLParam("resource", resourceURL), } - token, err := cfg.Exchange(ctx, authResult.Code, opts...) + clientCtx := context.WithValue(ctx, oauth2.HTTPClient, h.config.Client) + token, err := cfg.Exchange(clientCtx, authResult.Code, opts...) if err != nil { return fmt.Errorf("token exchange failed: %w", err) } - h.tokenSource = cfg.TokenSource(ctx, token) + h.tokenSource = cfg.TokenSource(clientCtx, token) return nil } diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index c00e7642..d371cba9 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -125,8 +125,11 @@ func TestAuthorize_ForbiddenUnhandledError(t *testing.T) { "WWW-Authenticate", "Bearer error=invalid_token", ) - handler := &AuthorizationCodeHandler{} // No config needed for this test. - err := handler.Authorize(t.Context(), req, resp) + handler, err := NewAuthorizationCodeHandler(validConfig()) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + err = handler.Authorize(t.Context(), req, resp) if err != nil { t.Fatalf("Authorize() failed: %v", err) } @@ -200,16 +203,7 @@ func TestNewAuthorizationCodeHandler_Success(t *testing.T) { } func TestNewAuthorizationCodeHandler_Error(t *testing.T) { - validConfig := func() *AuthorizationCodeHandlerConfig { - return &AuthorizationCodeHandlerConfig{ - ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, - RedirectURL: "https://example.com/callback", - AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { - return nil, nil - }, - } - } - // Ensure the base config is valid + // Ensure the base config is valid. if _, err := NewAuthorizationCodeHandler(validConfig()); err != nil { t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) } @@ -324,7 +318,11 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { } func TestGetProtectedResourceMetadata_Success(t *testing.T) { - handler := &AuthorizationCodeHandler{} // No config needed for this method + handler, err := NewAuthorizationCodeHandler(validConfig()) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) + } + pathForChallenge := "/protected-resource" tests := []struct { @@ -398,8 +396,11 @@ func TestGetProtectedResourceMetadata_Success(t *testing.T) { } func TestGetProtectedResourceMetadata_Backcompat(t *testing.T) { + handler, err := NewAuthorizationCodeHandler(validConfig()) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) + } var challenges []oauthex.Challenge - handler := &AuthorizationCodeHandler{} // No config needed for this method got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, "http://localhost:1234/resource") if err != nil { t.Fatalf("getProtectedResourceMetadata() error = %v", err) @@ -423,8 +424,11 @@ func TestGetProtectedResourceMetadata_Error(t *testing.T) { ScopesSupported: []string{"read", "write"}, } mux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(metadata)) + handler, err := NewAuthorizationCodeHandler(validConfig()) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) + } var challenges []oauthex.Challenge - handler := &AuthorizationCodeHandler{} // No config needed for this method got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, server.URL+"/resource") if err == nil || !strings.Contains(err.Error(), "authorization servers") { t.Errorf("getProtectedResourceMetadata() = %v, want error containing \"authorization servers\"", err) @@ -435,7 +439,10 @@ func TestGetProtectedResourceMetadata_Error(t *testing.T) { } func TestGetAuthServerMetadata(t *testing.T) { - handler := &AuthorizationCodeHandler{} // No config needed for this method + handler, err := NewAuthorizationCodeHandler(validConfig()) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) + } tests := []struct { name string @@ -563,11 +570,11 @@ func TestHandleRegistration(t *testing.T) { ClientIDMetadataDocumentSupported: true, }, handlerConfig: &AuthorizationCodeHandlerConfig{ - ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"}, + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com/metadata.json"}, }, want: &resolvedClientConfig{ registrationType: registrationTypeClientIDMetadataDocument, - clientID: "https://client.example.com", + clientID: "https://client.example.com/metadata.json", }, }, { @@ -597,7 +604,7 @@ func TestHandleRegistration(t *testing.T) { { name: "NoneSupported", handlerConfig: &AuthorizationCodeHandlerConfig{ - ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"}, + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com/metadata.json"}, }, wantError: true, }, @@ -607,7 +614,14 @@ func TestHandleRegistration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{RegistrationConfig: tt.serverConfig}) s.Start(t) - handler := &AuthorizationCodeHandler{config: tt.handlerConfig} + tt.handlerConfig.AuthorizationCodeFetcher = func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + return nil, nil + } + tt.handlerConfig.RedirectURL = "https://example.com/callback" + handler, err := NewAuthorizationCodeHandler(tt.handlerConfig) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler() error = %v, want nil", err) + } asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ AuthorizationServers: []string{s.URL()}, }) @@ -644,11 +658,20 @@ func TestDynamicRegistration(t *testing.T) { }, }) s.Start(t) - handler := &AuthorizationCodeHandler{config: &AuthorizationCodeHandlerConfig{ + handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{ DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ - Metadata: &oauthex.ClientRegistrationMetadata{}, + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + }, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + return nil, nil }, - }} + }) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler() error = %v", err) + } asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ AuthorizationServers: []string{s.URL()}, }) @@ -672,3 +695,15 @@ func TestDynamicRegistration(t *testing.T) { t.Errorf("handleRegistration() authStyle = %v, want %v", got.authStyle, oauth2.AuthStyleInHeader) } } + +// validConfig for test to create an AuthorizationCodeHandler using its constructor. +// Values that are relevant to the test should be set explicitly. +func validConfig() *AuthorizationCodeHandlerConfig { + return &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + return nil, nil + }, + } +} diff --git a/docs/protocol.md b/docs/protocol.md index c4f1cefc..aeff5e97 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -13,6 +13,7 @@ 1. [Security](#security) 1. [Confused Deputy](#confused-deputy) 1. [Token Passthrough](#token-passthrough) + 1. [Server-Side Request Forgery (SSRF)](#server-side-request-forgery-(ssrf)) 1. [Session Hijacking](#session-hijacking) 1. [Utilities](#utilities) 1. [Cancellation](#cancellation) @@ -363,7 +364,7 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, is mostly the responsibility of the MCP Proxy server implementation. The SDK client does generate cryptographically secure random `state` values for each authorization @@ -372,15 +373,36 @@ Mismatched state values will result in an error. ### Token Passthrough -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure +The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure of tokens and is the responsibility of the [`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) provided to [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). +### Server-Side Request Forgery (SSRF) + +The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-3) are as follows: + +- _Enforce HTTPS_. The OAuth helpers provided by the SDK reject the `http://` URLs +except loopback addresses (`localhost`, `127.0.0.1`, `::1`). + +- _Block Private IP Ranges_. The OAuth helpers provided by the SDK allow passing +a custom `http.Client`. Developers are advised to customize the client it with +appropriate network protections, including IP range blocking. The SDK does not provide +this capability out of the box. + +- _Validate Redirect Targets_. Similarly to previous point, customized `http.Client` +can be used to validate network hops. The SDK does not provide this capability out +of the box. + +- _Use Egress Proxies_. This is out of scope for the SDK and can be configured separately. + +- _DNS Resolution Considerations_. The SDK has DNS rebinding protection on the server side which is enabled by default. For the client side, consider providing +a custom `http.Client` that would implement DNS pinning. + ### Session Hijacking -The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows: +The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-4) are as follows: - _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken) middleware function will verify all HTTP requests that it receives. It is the diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 3771b581..5134a181 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -289,7 +289,7 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, is mostly the responsibility of the MCP Proxy server implementation. The SDK client does generate cryptographically secure random `state` values for each authorization @@ -298,15 +298,36 @@ Mismatched state values will result in an error. ### Token Passthrough -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure +The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure of tokens and is the responsibility of the [`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) provided to [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). +### Server-Side Request Forgery (SSRF) + +The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-3) are as follows: + +- _Enforce HTTPS_. The OAuth helpers provided by the SDK reject the `http://` URLs +except loopback addresses (`localhost`, `127.0.0.1`, `::1`). + +- _Block Private IP Ranges_. The OAuth helpers provided by the SDK allow passing +a custom `http.Client`. Developers are advised to customize the client it with +appropriate network protections, including IP range blocking. The SDK does not provide +this capability out of the box. + +- _Validate Redirect Targets_. Similarly to previous point, customized `http.Client` +can be used to validate network hops. The SDK does not provide this capability out +of the box. + +- _Use Egress Proxies_. This is out of scope for the SDK and can be configured separately. + +- _DNS Resolution Considerations_. The SDK has DNS rebinding protection on the server side which is enabled by default. For the client side, consider providing +a custom `http.Client` that would implement DNS pinning. + ### Session Hijacking -The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows: +The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-4) are as follows: - _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken) middleware function will verify all HTTP requests that it receives. It is the diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index b05d80b6..36210576 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -14,9 +14,6 @@ import ( "errors" "fmt" "net/http" - "net/url" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -136,13 +133,9 @@ type AuthServerMeta struct { // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.Client) (*AuthServerMeta, error) { - u, err := url.Parse(metadataURL) - if err != nil { - return nil, err - } // Only allow HTTP for local addresses (testing or development purposes). - if !util.IsLoopback(u.Host) && u.Scheme != "https" { - return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) + if err := checkHTTPSOrLoopback(metadataURL); err != nil { + return nil, fmt.Errorf("metadataURL: %v", err) } asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20) if err != nil { @@ -173,6 +166,8 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http. // validateAuthServerMetaURLs validates all URL fields in AuthServerMeta // to ensure they don't use dangerous schemes that could enable XSS attacks. +// It also validates that URLs likely to be called by the client use +// HTTPS or are loopback addresses. func validateAuthServerMetaURLs(asm *AuthServerMeta) error { urls := []struct { name string @@ -194,5 +189,22 @@ func validateAuthServerMetaURLs(asm *AuthServerMeta) error { return fmt.Errorf("%s: %w", u.name, err) } } + + urls = []struct { + name string + value string + }{ + {"authorization_endpoint", asm.AuthorizationEndpoint}, + {"token_endpoint", asm.TokenEndpoint}, + {"registration_endpoint", asm.RegistrationEndpoint}, + {"introspection_endpoint", asm.IntrospectionEndpoint}, + } + + for _, u := range urls { + if err := checkHTTPSOrLoopback(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil } diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index 836a4201..d8aeb3c2 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -17,6 +17,8 @@ import ( "net/http" "net/url" "strings" + + "github.com/modelcontextprotocol/go-sdk/internal/util" ) type httpStatusError struct { @@ -78,3 +80,17 @@ func checkURLScheme(u string) error { } return nil } + +func checkHTTPSOrLoopback(addr string) error { + if addr == "" { + return nil + } + u, err := url.Parse(addr) + if err != nil { + return err + } + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return fmt.Errorf("URL %q does not use HTTPS or is not a loopback address", addr) + } + return nil +} diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index 8b911cad..4680c153 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -98,13 +98,9 @@ func resourceMetadataURL(cs []Challenge) string { // - The authorization_servers field of the resulting metadata is checked for dangerous URL schemes. func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) - u, err := url.Parse(metadataURL) - if err != nil { - return nil, err - } // Only allow HTTP for local addresses (testing or development purposes). - if !util.IsLoopback(u.Host) && u.Scheme != "https" { - return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) + if err := checkHTTPSOrLoopback(metadataURL); err != nil { + return nil, fmt.Errorf("metadataURL: %v", err) } prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL, 1<<20) if err != nil { @@ -115,9 +111,12 @@ func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, resourceURL) } // Validate the authorization server URLs to prevent XSS attacks (see #526). - for _, u := range prm.AuthorizationServers { + for i, u := range prm.AuthorizationServers { if err := checkURLScheme(u); err != nil { - return nil, err + return nil, fmt.Errorf("authorization_servers[%d]: %v", i, err) + } + if err := checkHTTPSOrLoopback(u); err != nil { + return nil, fmt.Errorf("authorization_servers[%d]: %v", i, err) } } return prm, nil