Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

186 changes: 135 additions & 51 deletions crates/mcp-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,25 @@ impl Message {

pub struct McpProxy {
transport: InnerTransport,
/// ID of the last forwarded request that has not yet received a response.
///
/// Set when a request (message with both `id` and `method`) is successfully sent to the
/// backend. Cleared when a matching response (same `id`, no `method`) is received.
/// Server-initiated notifications leave this unchanged.
///
/// Used to build a meaningful JSON-RPC error when `ReadError::Fatal` fires: if a request
/// is pending we can correlate the error to it; otherwise we send a notification.
pending_request_id: Option<i32>,
}

/// Error that can occur when sending a message.
#[derive(Debug)]
pub enum SendError {
/// Fatal error - the proxy must stop as the connection is broken.
Fatal {
/// Optional error message to send back when a request ID is detected.
message: Option<Message>,
/// Message to send back to the client: a JSON-RPC error response if there was a pending
/// request ID, or a `$/proxy/serverDisconnected` notification otherwise.
message: Message,
/// The underlying error for logging/debugging.
source: anyhow::Error,
},
Expand All @@ -168,7 +178,13 @@ pub enum SendError {
#[derive(Debug)]
pub enum ReadError {
/// Fatal error - the proxy must stop as the connection is broken.
Fatal(anyhow::Error),
Fatal {
/// Message to send back to the client: a JSON-RPC error response if there was a pending
/// request ID, or a `$/proxy/serverDisconnected` notification otherwise.
message: Message,
/// The underlying error for logging/debugging.
source: anyhow::Error,
},
/// Transient error - the proxy can continue operating.
Transient(anyhow::Error),
}
Expand Down Expand Up @@ -200,7 +216,10 @@ impl McpProxy {
}
};

Ok(McpProxy { transport })
Ok(McpProxy {
transport,
pending_request_id: None,
})
}

/// Send a message to the peer.
Expand All @@ -222,23 +241,31 @@ impl McpProxy {
}

// Try to parse as request first, then as response.
let request_id = match JsonRpcMessage::parse(message) {
Ok(request) => {
match (request.id, request.method) {
let (request_id, is_request) = match JsonRpcMessage::parse(message) {
Ok(msg) => {
let is_request = match (msg.id, msg.method) {
(None, None) => {
warn!(
jsonrpc = request.jsonrpc,
jsonrpc = msg.jsonrpc,
"Sending a malformed JSON-RPC message (missing both `id` and `method`)"
)
);
false
}
(None, Some(method)) => {
debug!(jsonrpc = request.jsonrpc, method, "Sending a notification")
debug!(jsonrpc = msg.jsonrpc, method, "Sending a notification");
false
}
(Some(id), None) => {
debug!(jsonrpc = msg.jsonrpc, id, "Sending a response");
false
}
(Some(id), Some(method)) => {
debug!(jsonrpc = msg.jsonrpc, method, id, "Sending a request");
true
}
(Some(id), None) => debug!(jsonrpc = request.jsonrpc, id, "Sending a response"),
(Some(id), Some(method)) => debug!(jsonrpc = request.jsonrpc, method, id, "Sending a request"),
};

request.id
(msg.id, is_request)
}
Err(error) => {
// Not a JSON-RPC message, try best-effort ID extraction.
Expand All @@ -250,7 +277,7 @@ impl McpProxy {
warn!(error = format!("{error:#}"), "Sending a malformed JSON-RPC message");
}

id
(id, false)
}
};

Expand All @@ -274,8 +301,9 @@ impl McpProxy {
error!(error = format!("{error:#}"), "Couldn't forward request");

let message = if let Some(id) = request_id {
let detail = json_escape_str(&format!("{error:#}"));
let json_rpc_error_response = format!(
r#"{{"jsonrpc":"2.0","id":{id},"error":{{"code":{FORWARD_FAILURE_CODE},"message":"Forward failure: {error:#}"}}}}"#
r#"{{"jsonrpc":"2.0","id":{id},"error":{{"code":{FORWARD_FAILURE_CODE},"message":"Forward failure: {detail}"}}}}"#
);
Some(Message::normalize(json_rpc_error_response))
} else {
Expand All @@ -296,55 +324,31 @@ impl McpProxy {
}
};

return ret;

fn extract_id_best_effort(request_str: &str) -> Option<i32> {
let idx = request_str.find("\"id\"")?;

let mut rest = request_str[idx + "\"id\"".len()..].chars();

loop {
if rest.next()? == ':' {
break;
}
}

let mut acc = String::new();

loop {
match rest.next() {
Some(',') => break,
Some(ch) => acc.push(ch),
None => break,
}
}

acc.parse().ok()
// Track pending request ID for Process/NamedPipe transports so that if the backend
// breaks while we're waiting for the response, we can synthesise a meaningful error.
if ret.is_ok() && is_request {
self.pending_request_id = request_id;
}

return ret;

fn handle_write_result(result: std::io::Result<()>, request_id: Option<i32>) -> Result<(), SendError> {
match result {
Ok(()) => Ok(()),
Err(io_error) => {
// Classify the error.
if is_fatal_io_error(&io_error) {
let message = if let Some(id) = request_id {
let json_rpc_error_response = format!(
r#"{{"jsonrpc":"2.0","id":{id},"error":{{"code":{FORWARD_FAILURE_CODE},"message":"connection broken: {io_error}"}}}}"#
);
Some(Message::normalize(json_rpc_error_response))
} else {
None
};
let message = make_server_disconnected_message(request_id, &io_error.to_string());

Err(SendError::Fatal {
message,
source: anyhow::Error::new(io_error),
})
} else {
let message = if let Some(id) = request_id {
let detail = json_escape_str(&io_error.to_string());
let json_rpc_error_response = format!(
r#"{{"jsonrpc":"2.0","id":{id},"error":{{"code":{FORWARD_FAILURE_CODE},"message":"Forward failure: {io_error}"}}}}"#
r#"{{"jsonrpc":"2.0","id":{id},"error":{{"code":{FORWARD_FAILURE_CODE},"message":"Forward failure: {detail}"}}}}"#
);
Some(Message::normalize(json_rpc_error_response))
} else {
Expand Down Expand Up @@ -397,13 +401,29 @@ impl McpProxy {
match result {
Ok(message) => {
trace!(%message, "Inbound message");

// Clear the pending request ID when the matching response arrives.
// We use best-effort ID extraction to avoid a full JSON parse in the hot path.
// Server-initiated requests with the same ID are rare enough that we accept
// the minor inaccuracy of treating any matching-ID message as the response.
if let Some(pending_id) = self.pending_request_id
&& extract_id_best_effort(message.as_raw()) == Some(pending_id)
{
self.pending_request_id = None;
}

Ok(message)
}
Err(io_error) => {
if is_fatal_io_error(&io_error) {
Err(ReadError::Fatal(anyhow::Error::new(io_error)))
let is_fatal = is_fatal_io_error(&io_error);
let message = make_server_disconnected_message(self.pending_request_id, &io_error.to_string());
self.pending_request_id = None;
let source = anyhow::Error::new(io_error);

if is_fatal {
Err(ReadError::Fatal { message, source })
} else {
Err(ReadError::Transient(anyhow::Error::new(io_error)))
Err(ReadError::Transient(source))
}
}
}
Expand Down Expand Up @@ -684,6 +704,46 @@ fn extract_sse_json_line(body: &str) -> Option<&str> {
body.lines().find_map(|l| l.strip_prefix("data: ").map(|s| s.trim()))
}

/// Escape a string for safe embedding inside a JSON string value.
fn json_escape_str(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'"' => out.push_str("\\\""),
'\\' => out.push_str("\\\\"),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
'\t' => out.push_str("\\t"),
c if (c as u32) < 0x20 => {
use std::fmt::Write as _;
write!(out, "\\u{:04x}", c as u32).expect("write to String is infallible");
}
c => out.push(c),
}
}
out
}

/// Build the message to send to the MCP client when the backend connection breaks fatally.
///
/// - If there is a `pending_request_id`, returns a JSON-RPC error response correlating the
/// failure to that outstanding request.
/// - Otherwise, returns a `$/proxy/serverDisconnected` notification so the client knows the
/// server is no longer reachable without having an in-flight request to correlate to.
fn make_server_disconnected_message(pending_request_id: Option<i32>, error_detail: &str) -> Message {
let detail = json_escape_str(error_detail);
let raw = if let Some(id) = pending_request_id {
format!(
r#"{{"jsonrpc":"2.0","id":{id},"error":{{"code":{FORWARD_FAILURE_CODE},"message":"server disconnected: {detail}"}}}}"#
)
} else {
format!(
r#"{{"jsonrpc":"2.0","method":"$/proxy/serverDisconnected","params":{{"message":"server disconnected: {detail}"}}}}"#
)
};
Message::normalize(raw)
}

/// Check if an I/O error is fatal (connection broken)
///
/// For process stdio and named pipe transports, these errors indicate the connection is permanently broken:
Expand All @@ -699,3 +759,27 @@ fn is_fatal_io_error(error: &std::io::Error) -> bool {
std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::UnexpectedEof
)
}

fn extract_id_best_effort(message: &str) -> Option<i32> {
let idx = message.find("\"id\"")?;

let mut rest = message[idx + "\"id\"".len()..].chars();

loop {
if rest.next()? == ':' {
break;
}
}

let mut acc = String::new();

loop {
match rest.next() {
Some(',') | Some('}') => break,
Some(ch) => acc.push(ch),
None => break,
}
}

acc.trim().parse().ok()
}
4 changes: 2 additions & 2 deletions jetsocat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ fn parse_env_variable_as_args(env_var_str: &str) -> Vec<String> {

fn apply_common_flags(cmd: Command) -> Command {
cmd.flag(Flag::new("log-file", FlagType::String).description("Specify filepath for log file"))
.flag(Flag::new("log-term", FlagType::Bool).description("Print logs to stdout instead of log file"))
.flag(Flag::new("log-term", FlagType::Bool).description("Print logs to stderr instead of log file"))
.flag(
Flag::new("color", FlagType::String)
.description("When to enable colored output for logs (possible values: `always`, `never` and `auto`)"),
Expand Down Expand Up @@ -996,7 +996,7 @@ fn setup_logger(logging: &Logging, coloring: Coloring) -> LoggerGuard {
Coloring::Auto => true,
};

let (non_blocking_stdio, guard) = tracing_appender::non_blocking(std::io::stdout()); // FIXME: Should be to stderr.
let (non_blocking_stdio, guard) = tracing_appender::non_blocking(std::io::stderr());
let stdio_layer = fmt::layer().with_writer(non_blocking_stdio).with_ansi(ansi);

(stdio_layer, guard)
Expand Down
9 changes: 3 additions & 6 deletions jetsocat/src/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ pub(crate) async fn run_mcp_proxy(pipe: Pipe, mut mcp_proxy: mcp_proxy::McpProxy
}
Err(mcp_proxy::SendError::Fatal { message, source }) => {
error!(error = format!("{source:#}"), "Fatal error sending message, stopping proxy");

if let Some(msg) = message {
let _ = write_flush_message(&mut writer, msg).await;
}

let _ = write_flush_message(&mut writer, message).await;
return Ok(());
}
}
Expand All @@ -68,8 +64,9 @@ pub(crate) async fn run_mcp_proxy(pipe: Pipe, mut mcp_proxy: mcp_proxy::McpProxy
Err(mcp_proxy::ReadError::Transient(source)) => {
warn!(error = format!("{source:#}"), "Transient error reading from peer");
}
Err(mcp_proxy::ReadError::Fatal(source)) => {
Err(mcp_proxy::ReadError::Fatal { message, source }) => {
error!(error = format!("{source:#}"), "Fatal error reading from peer, stopping proxy");
let _ = write_flush_message(&mut writer, message).await;
return Ok(());
}
}
Expand Down
1 change: 1 addition & 0 deletions testsuite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ serde_json = "1"
serde = { version = "1", features = ["derive"] }
tempfile = "3"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "net", "process"] }
tokio-util = "0.7"
typed-builder = "0.21"
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] }

Expand Down
Loading
Loading