Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions auth/authorization_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ type AuthorizationCodeHandlerConfig struct {
// It should return the authorization code and state once the Authorization Server
// redirects back to the RedirectURL.
AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error)

// Client is an optional HTTP client to use for HTTP requests.
// It is used for the following requests:
// - Fetching Protected Resource Metadata
// - Fetching Authorization Server Metadata
// - Registering a client dynamically
// - Exchanging an authorization code for an access token
// - Refreshing an access token
// Custom clients can include additional security configurations,
// such as SSRF protections, see
// https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#server-side-request-forgery-ssrf
// If not provided, http.DefaultClient will be used.
Client *http.Client
}

// AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses
Expand Down Expand Up @@ -166,6 +179,9 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho
// it should have been set by now. Otherwise, it is required.
return nil, errors.New("RedirectURL is required")
}
if config.Client == nil {
config.Client = http.DefaultClient
}
return &AuthorizationCodeHandler{config: config}, nil
}

Expand Down Expand Up @@ -280,7 +296,7 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont
// Use MCP server URL as the resource URI per
// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri.
for _, url := range protectedResourceMetadataURLs(resourceMetadataURLFromChallenges(wwwChallenges), mcpServerURL) {
prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient)
prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, h.config.Client)
if err != nil {
errs = append(errs, err)
continue
Expand Down Expand Up @@ -359,7 +375,7 @@ func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL {
func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) {
authServerURL := prm.AuthorizationServers[0]
for _, u := range authorizationServerMetadataURLs(authServerURL) {
asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient)
asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, h.config.Client)
if err != nil {
return nil, fmt.Errorf("failed to get authorization server metadata: %w", err)
}
Expand Down Expand Up @@ -488,7 +504,7 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm *
// 3. Attempt to use dynamic client registration.
dcrCfg := h.config.DynamicClientRegistrationConfig
if dcrCfg != nil && asm.RegistrationEndpoint != "" {
regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient)
regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, h.config.Client)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
Expand Down Expand Up @@ -542,10 +558,11 @@ func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context
oauth2.VerifierOption(authResult.usedCodeVerifier),
oauth2.SetAuthURLParam("resource", resourceURL),
}
token, err := cfg.Exchange(ctx, authResult.Code, opts...)
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, h.config.Client)
token, err := cfg.Exchange(clientCtx, authResult.Code, opts...)
if err != nil {
return fmt.Errorf("token exchange failed: %w", err)
}
h.tokenSource = cfg.TokenSource(ctx, token)
h.tokenSource = cfg.TokenSource(clientCtx, token)
return nil
}
81 changes: 58 additions & 23 deletions auth/authorization_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ func TestAuthorize_ForbiddenUnhandledError(t *testing.T) {
"WWW-Authenticate",
"Bearer error=invalid_token",
)
handler := &AuthorizationCodeHandler{} // No config needed for this test.
err := handler.Authorize(t.Context(), req, resp)
handler, err := NewAuthorizationCodeHandler(validConfig())
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler failed: %v", err)
}
err = handler.Authorize(t.Context(), req, resp)
if err != nil {
t.Fatalf("Authorize() failed: %v", err)
}
Expand Down Expand Up @@ -200,16 +203,7 @@ func TestNewAuthorizationCodeHandler_Success(t *testing.T) {
}

func TestNewAuthorizationCodeHandler_Error(t *testing.T) {
validConfig := func() *AuthorizationCodeHandlerConfig {
return &AuthorizationCodeHandlerConfig{
ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"},
RedirectURL: "https://example.com/callback",
AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
return nil, nil
},
}
}
// Ensure the base config is valid
// Ensure the base config is valid.
if _, err := NewAuthorizationCodeHandler(validConfig()); err != nil {
t.Fatalf("NewAuthorizationCodeHandler failed: %v", err)
}
Expand Down Expand Up @@ -324,7 +318,11 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) {
}

func TestGetProtectedResourceMetadata_Success(t *testing.T) {
handler := &AuthorizationCodeHandler{} // No config needed for this method
handler, err := NewAuthorizationCodeHandler(validConfig())
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

t.Fatalf("NewAuthorizationCodeHandler() error = %v", err)
}

pathForChallenge := "/protected-resource"

tests := []struct {
Expand Down Expand Up @@ -398,8 +396,11 @@ func TestGetProtectedResourceMetadata_Success(t *testing.T) {
}

func TestGetProtectedResourceMetadata_Backcompat(t *testing.T) {
handler, err := NewAuthorizationCodeHandler(validConfig())
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler() error = %v", err)
}
var challenges []oauthex.Challenge
handler := &AuthorizationCodeHandler{} // No config needed for this method
got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, "http://localhost:1234/resource")
if err != nil {
t.Fatalf("getProtectedResourceMetadata() error = %v", err)
Expand All @@ -423,8 +424,11 @@ func TestGetProtectedResourceMetadata_Error(t *testing.T) {
ScopesSupported: []string{"read", "write"},
}
mux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(metadata))
handler, err := NewAuthorizationCodeHandler(validConfig())
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler() error = %v", err)
}
var challenges []oauthex.Challenge
handler := &AuthorizationCodeHandler{} // No config needed for this method
got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, server.URL+"/resource")
if err == nil || !strings.Contains(err.Error(), "authorization servers") {
t.Errorf("getProtectedResourceMetadata() = %v, want error containing \"authorization servers\"", err)
Expand All @@ -435,7 +439,10 @@ func TestGetProtectedResourceMetadata_Error(t *testing.T) {
}

func TestGetAuthServerMetadata(t *testing.T) {
handler := &AuthorizationCodeHandler{} // No config needed for this method
handler, err := NewAuthorizationCodeHandler(validConfig())
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler() error = %v", err)
}

tests := []struct {
name string
Expand Down Expand Up @@ -563,11 +570,11 @@ func TestHandleRegistration(t *testing.T) {
ClientIDMetadataDocumentSupported: true,
},
handlerConfig: &AuthorizationCodeHandlerConfig{
ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"},
ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com/metadata.json"},
},
want: &resolvedClientConfig{
registrationType: registrationTypeClientIDMetadataDocument,
clientID: "https://client.example.com",
clientID: "https://client.example.com/metadata.json",
},
},
{
Expand Down Expand Up @@ -597,7 +604,7 @@ func TestHandleRegistration(t *testing.T) {
{
name: "NoneSupported",
handlerConfig: &AuthorizationCodeHandlerConfig{
ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"},
ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com/metadata.json"},
},
wantError: true,
},
Expand All @@ -607,7 +614,14 @@ func TestHandleRegistration(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{RegistrationConfig: tt.serverConfig})
s.Start(t)
handler := &AuthorizationCodeHandler{config: tt.handlerConfig}
tt.handlerConfig.AuthorizationCodeFetcher = func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
return nil, nil
}
tt.handlerConfig.RedirectURL = "https://example.com/callback"
handler, err := NewAuthorizationCodeHandler(tt.handlerConfig)
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler() error = %v, want nil", err)
}
asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{
AuthorizationServers: []string{s.URL()},
})
Expand Down Expand Up @@ -644,11 +658,20 @@ func TestDynamicRegistration(t *testing.T) {
},
})
s.Start(t)
handler := &AuthorizationCodeHandler{config: &AuthorizationCodeHandlerConfig{
handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{
DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{
Metadata: &oauthex.ClientRegistrationMetadata{},
Metadata: &oauthex.ClientRegistrationMetadata{
RedirectURIs: []string{"https://example.com/callback"},
},
},
RedirectURL: "https://example.com/callback",
AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
return nil, nil
},
}}
})
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler() error = %v", err)
}
asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{
AuthorizationServers: []string{s.URL()},
})
Expand All @@ -672,3 +695,15 @@ func TestDynamicRegistration(t *testing.T) {
t.Errorf("handleRegistration() authStyle = %v, want %v", got.authStyle, oauth2.AuthStyleInHeader)
}
}

// validConfig for test to create an AuthorizationCodeHandler using its constructor.
// Values that are relevant to the test should be set explicitly.
func validConfig() *AuthorizationCodeHandlerConfig {
return &AuthorizationCodeHandlerConfig{
ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"},
RedirectURL: "https://example.com/callback",
AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
return nil, nil
},
}
}
28 changes: 25 additions & 3 deletions docs/protocol.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
1. [Security](#security)
1. [Confused Deputy](#confused-deputy)
1. [Token Passthrough](#token-passthrough)
1. [Server-Side Request Forgery (SSRF)](#server-side-request-forgery-(ssrf))
1. [Session Hijacking](#session-hijacking)
1. [Utilities](#utilities)
1. [Cancellation](#cancellation)
Expand Down Expand Up @@ -363,7 +364,7 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi

### Confused Deputy

The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation),
The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation),
obtaining user consent for dynamically registered clients, is mostly the
responsibility of the MCP Proxy server implementation. The SDK client does
generate cryptographically secure random `state` values for each authorization
Expand All @@ -372,15 +373,36 @@ Mismatched state values will result in an error.

### Token Passthrough

The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure
The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure
of tokens and is the responsibility of the
[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier)
provided to
[`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken).

### Server-Side Request Forgery (SSRF)

The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-3) are as follows:

- _Enforce HTTPS_. The OAuth helpers provided by the SDK reject the `http://` URLs
except loopback addresses (`localhost`, `127.0.0.1`, `::1`).

- _Block Private IP Ranges_. The OAuth helpers provided by the SDK allow passing
a custom `http.Client`. Developers are advised to customize the client it with
appropriate network protections, including IP range blocking. The SDK does not provide
this capability out of the box.

- _Validate Redirect Targets_. Similarly to previous point, customized `http.Client`
can be used to validate network hops. The SDK does not provide this capability out
of the box.

- _Use Egress Proxies_. This is out of scope for the SDK and can be configured separately.

- _DNS Resolution Considerations_. The SDK has DNS rebinding protection on the server side which is enabled by default. For the client side, consider providing
a custom `http.Client` that would implement DNS pinning.

### Session Hijacking

The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows:
The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-4) are as follows:

- _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken)
middleware function will verify all HTTP requests that it receives. It is the
Expand Down
27 changes: 24 additions & 3 deletions internal/docs/protocol.src.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi

### Confused Deputy

The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation),
The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation),
obtaining user consent for dynamically registered clients, is mostly the
responsibility of the MCP Proxy server implementation. The SDK client does
generate cryptographically secure random `state` values for each authorization
Expand All @@ -298,15 +298,36 @@ Mismatched state values will result in an error.

### Token Passthrough

The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure
The [mitigation](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure
of tokens and is the responsibility of the
[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier)
provided to
[`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken).

### Server-Side Request Forgery (SSRF)

The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-3) are as follows:

- _Enforce HTTPS_. The OAuth helpers provided by the SDK reject the `http://` URLs
except loopback addresses (`localhost`, `127.0.0.1`, `::1`).

- _Block Private IP Ranges_. The OAuth helpers provided by the SDK allow passing
a custom `http.Client`. Developers are advised to customize the client it with
appropriate network protections, including IP range blocking. The SDK does not provide
this capability out of the box.

- _Validate Redirect Targets_. Similarly to previous point, customized `http.Client`
can be used to validate network hops. The SDK does not provide this capability out
of the box.

- _Use Egress Proxies_. This is out of scope for the SDK and can be configured separately.

- _DNS Resolution Considerations_. The SDK has DNS rebinding protection on the server side which is enabled by default. For the client side, consider providing
a custom `http.Client` that would implement DNS pinning.

### Session Hijacking

The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows:
The [mitigations](https://modelcontextprotocol.io/docs/tutorials/security/security_best_practices#mitigation-4) are as follows:

- _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken)
middleware function will verify all HTTP requests that it receives. It is the
Expand Down
30 changes: 21 additions & 9 deletions oauthex/auth_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ import (
"errors"
"fmt"
"net/http"
"net/url"

"github.com/modelcontextprotocol/go-sdk/internal/util"
)

// AuthServerMeta represents the metadata for an OAuth 2.0 authorization server,
Expand Down Expand Up @@ -136,13 +133,9 @@ type AuthServerMeta struct {
//
// [RFC 8414]: https://tools.ietf.org/html/rfc8414
func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.Client) (*AuthServerMeta, error) {
u, err := url.Parse(metadataURL)
if err != nil {
return nil, err
}
// Only allow HTTP for local addresses (testing or development purposes).
if !util.IsLoopback(u.Host) && u.Scheme != "https" {
return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL)
if err := checkHTTPSOrLoopback(metadataURL); err != nil {
return nil, fmt.Errorf("metadataURL: %v", err)
}
asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20)
if err != nil {
Expand Down Expand Up @@ -173,6 +166,8 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.

// validateAuthServerMetaURLs validates all URL fields in AuthServerMeta
// to ensure they don't use dangerous schemes that could enable XSS attacks.
// It also validates that URLs likely to be called by the client use
// HTTPS or are loopback addresses.
func validateAuthServerMetaURLs(asm *AuthServerMeta) error {
urls := []struct {
name string
Expand All @@ -194,5 +189,22 @@ func validateAuthServerMetaURLs(asm *AuthServerMeta) error {
return fmt.Errorf("%s: %w", u.name, err)
}
}

urls = []struct {
name string
value string
}{
{"authorization_endpoint", asm.AuthorizationEndpoint},
{"token_endpoint", asm.TokenEndpoint},
{"registration_endpoint", asm.RegistrationEndpoint},
{"introspection_endpoint", asm.IntrospectionEndpoint},
}

for _, u := range urls {
if err := checkHTTPSOrLoopback(u.value); err != nil {
return fmt.Errorf("%s: %w", u.name, err)
}
}

return nil
}
Loading
Loading