diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index b8d4f3f4..2236590a 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -7,16 +7,16 @@ use std::{ use async_trait::async_trait; use oauth2::{ AsyncHttpClient, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, - EmptyExtraTokenFields, HttpClientError, HttpRequest, HttpResponse, PkceCodeChallenge, - PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, StandardTokenResponse, - TokenResponse, TokenUrl, - basic::{BasicClient, BasicTokenType}, + EmptyExtraTokenFields, ExtraTokenFields, HttpClientError, HttpRequest, HttpResponse, + PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, + StandardTokenResponse, TokenResponse, TokenUrl, basic::BasicTokenType, }; use reqwest::{ Client as HttpClient, IntoUrl, StatusCode, Url, header::{AUTHORIZATION, WWW_AUTHENTICATE}, }; use serde::{Deserialize, Serialize}; +use serde_json::Value; use thiserror::Error; use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, warn}; @@ -126,6 +126,32 @@ pub struct StoredAuthorizationState { pub created_at: u64, } +/// A transparent wrapper around a JSON object that captures any extra fields returned by the +/// authorization server during token exchange that are not part of the standard OAuth 2.0 token +/// response. +/// +/// OAuth providers may include non-standard fields alongside the +/// standard OAuth fields. Those fields are collected here so callers +/// can inspect them without losing data. +/// +/// The inner [`HashMap`] maps field names to their raw JSON values. +/// +/// # Accessing extra fields +/// +/// Extra fields are available through [`StandardTokenResponse::extra_fields()`], which returns a +/// reference to this struct. Use the inner map (`.0`) to look up individual fields by name: +/// +/// ```rust,ignore +/// // Obtain the token response from the AuthorizationManager, then: +/// if let Some(value) = token_response.extra_fields().0.get("vendorSpecificField") { +/// println!("vendorSpecificField = {value}"); +/// } +/// ``` +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct VendorExtraTokenFields(pub HashMap); + +impl ExtraTokenFields for VendorExtraTokenFields {} + impl StoredAuthorizationState { pub fn new(pkce_verifier: &PkceCodeVerifier, csrf_token: &CsrfToken) -> Self { Self { @@ -345,7 +371,18 @@ pub struct OAuthClientConfig { // add type aliases for oauth2 types type OAuthErrorResponse = oauth2::StandardErrorResponse; -pub type OAuthTokenResponse = StandardTokenResponse; + +/// The token response returned by the authorization server after a successful OAuth 2.0 flow. +/// +/// This is a [`StandardTokenResponse`] parameterised with [`VendorExtraTokenFields`], which means +/// it carries both the standard OAuth fields and +/// any vendor-specific fields the server may have included in the JSON response body. +/// +/// # Accessing vendor-specific fields +/// +/// Call [`extra_fields()`][OAuthTokenResponse::extra_fields] to obtain a reference to the +/// [`VendorExtraTokenFields`] wrapper, then index into its inner map. +pub type OAuthTokenResponse = StandardTokenResponse; type OAuthTokenIntrospection = oauth2::StandardTokenIntrospectionResponse; type OAuthRevocableToken = oauth2::StandardRevocableToken; @@ -581,7 +618,7 @@ impl AuthorizationManager { let redirect_url = RedirectUrl::new(config.redirect_uri.clone()) .map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?; - let mut client_builder = BasicClient::new(client_id.clone()) + let mut client_builder: OAuthClient = oauth2::Client::new(client_id.clone()) .set_auth_uri(auth_url) .set_token_uri(token_url) .set_redirect_uri(redirect_url); @@ -882,7 +919,7 @@ impl AuthorizationManager { &self, code: &str, csrf_token: &str, - ) -> Result, AuthError> { + ) -> Result { debug!("start exchange code for token: {:?}", code); let oauth_client = self .oauth_client @@ -1017,9 +1054,7 @@ impl AuthorizationManager { } /// refresh access token - pub async fn refresh_token( - &self, - ) -> Result, AuthError> { + pub async fn refresh_token(&self) -> Result { let oauth_client = self .oauth_client .as_ref() @@ -1551,7 +1586,7 @@ impl AuthorizationSession { &self, code: &str, csrf_token: &str, - ) -> Result, AuthError> { + ) -> Result { self.auth_manager .exchange_code_for_token(code, csrf_token) .await @@ -1876,6 +1911,7 @@ mod tests { AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore, OAuthClientConfig, ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url, }; + use crate::transport::auth::VendorExtraTokenFields; // -- url helpers -- @@ -2686,11 +2722,13 @@ mod tests { use super::{OAuthTokenResponse, StoredCredentials}; fn make_token_response(access_token: &str, expires_in_secs: Option) -> OAuthTokenResponse { - use oauth2::{AccessToken, EmptyExtraTokenFields, basic::BasicTokenType}; + use oauth2::{AccessToken, basic::BasicTokenType}; let mut resp = OAuthTokenResponse::new( AccessToken::new(access_token.to_string()), BasicTokenType::Bearer, - EmptyExtraTokenFields {}, + VendorExtraTokenFields { + ..Default::default() + }, ); if let Some(secs) = expires_in_secs { resp.set_expires_in(Some(&std::time::Duration::from_secs(secs)));