Skip to content

Commit 06155ce

Browse files
authored
feat(mcp): elicitation support (#5965)
1 parent 1db4070 commit 06155ce

24 files changed

+992
-22
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
use console::style;
2+
use serde_json::Value;
3+
use std::collections::HashMap;
4+
use std::io::{self, BufRead, IsTerminal, Write};
5+
6+
pub fn collect_elicitation_input(
7+
message: &str,
8+
schema: &Value,
9+
) -> io::Result<Option<HashMap<String, Value>>> {
10+
if !message.is_empty() {
11+
println!("\n{}", style(message).cyan());
12+
}
13+
14+
let properties = match schema.get("properties").and_then(|p| p.as_object()) {
15+
Some(props) => props,
16+
None => return Ok(Some(HashMap::new())),
17+
};
18+
19+
let required: Vec<&str> = schema
20+
.get("required")
21+
.and_then(|r| r.as_array())
22+
.map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
23+
.unwrap_or_default();
24+
25+
let mut data: HashMap<String, Value> = HashMap::new();
26+
27+
for (name, field_schema) in properties {
28+
let is_required = required.contains(&name.as_str());
29+
let field_type = field_schema
30+
.get("type")
31+
.and_then(|t| t.as_str())
32+
.unwrap_or("string");
33+
let description = field_schema.get("description").and_then(|d| d.as_str());
34+
let default = field_schema.get("default");
35+
let enum_values = field_schema.get("enum").and_then(|e| e.as_array());
36+
37+
// makes a little true/false toggle
38+
if field_type == "boolean" {
39+
let label = match description {
40+
Some(desc) => format!("{} ({})", name, desc),
41+
None => name.clone(),
42+
};
43+
let default_bool = default.and_then(|v| v.as_bool()).unwrap_or(false);
44+
45+
match cliclack::confirm(&label)
46+
.initial_value(default_bool)
47+
.interact()
48+
{
49+
Ok(v) => {
50+
data.insert(name.clone(), Value::Bool(v));
51+
}
52+
Err(e) if e.kind() == io::ErrorKind::Interrupted => return Ok(None),
53+
Err(e) => return Err(e),
54+
}
55+
continue;
56+
}
57+
58+
if let Some(options) = enum_values {
59+
let opts: Vec<&str> = options.iter().filter_map(|v| v.as_str()).collect();
60+
println!(" {}: {}", style("Options").dim(), opts.join(", "));
61+
}
62+
63+
print!("{}", style(name).yellow());
64+
if let Some(desc) = description {
65+
print!(" {}", style(format!("({})", desc)).dim());
66+
}
67+
if is_required {
68+
print!("{}", style("*").red());
69+
}
70+
if let Some(def) = default {
71+
print!(" {}", style(format!("[{}]", format_default(def))).dim());
72+
}
73+
print!(": ");
74+
io::stdout().flush()?;
75+
76+
let input = read_line()?;
77+
78+
// Handle Ctrl+C / EOF for cancellation
79+
if input.is_none() {
80+
return Ok(None);
81+
}
82+
let input = input.unwrap();
83+
84+
let value = if input.is_empty() {
85+
default.cloned()
86+
} else {
87+
Some(parse_value(&input, field_type, enum_values))
88+
};
89+
90+
if let Some(v) = value {
91+
if !v.is_null() {
92+
data.insert(name.clone(), v);
93+
}
94+
}
95+
96+
if is_required && !data.contains_key(name) {
97+
println!(
98+
"{}",
99+
style(format!("Required field '{}' is missing", name)).red()
100+
);
101+
return Ok(None);
102+
}
103+
}
104+
105+
println!();
106+
Ok(Some(data))
107+
}
108+
109+
fn read_line() -> io::Result<Option<String>> {
110+
if !std::io::stdin().is_terminal() {
111+
let mut line = String::new();
112+
io::stdin().lock().read_line(&mut line)?;
113+
return Ok(Some(line.trim().to_string()));
114+
}
115+
116+
let mut line = String::new();
117+
match io::stdin().lock().read_line(&mut line) {
118+
Ok(0) => Ok(None), // EOF
119+
Ok(_) => Ok(Some(line.trim().to_string())),
120+
Err(e) if e.kind() == io::ErrorKind::Interrupted => Ok(None),
121+
Err(e) => Err(e),
122+
}
123+
}
124+
125+
fn format_default(value: &Value) -> String {
126+
match value {
127+
Value::String(s) => s.clone(),
128+
Value::Bool(b) => b.to_string(),
129+
Value::Number(n) => n.to_string(),
130+
_ => value.to_string(),
131+
}
132+
}
133+
134+
fn parse_value(input: &str, field_type: &str, enum_values: Option<&Vec<Value>>) -> Value {
135+
if let Some(options) = enum_values {
136+
let valid: Vec<&str> = options.iter().filter_map(|v| v.as_str()).collect();
137+
if valid.contains(&input) {
138+
return Value::String(input.to_string());
139+
}
140+
if let Ok(idx) = input.parse::<usize>() {
141+
if idx > 0 && idx <= valid.len() {
142+
return Value::String(valid[idx - 1].to_string());
143+
}
144+
}
145+
}
146+
147+
match field_type {
148+
"boolean" => {
149+
let lower = input.to_lowercase();
150+
Value::Bool(matches!(lower.as_str(), "true" | "yes" | "y" | "1"))
151+
}
152+
"integer" => input
153+
.parse::<i64>()
154+
.map(|n| Value::Number(n.into()))
155+
.unwrap_or(Value::Null),
156+
"number" => input
157+
.parse::<f64>()
158+
.ok()
159+
.and_then(serde_json::Number::from_f64)
160+
.map(Value::Number)
161+
.unwrap_or(Value::Null),
162+
_ => Value::String(input.to_string()),
163+
}
164+
}

crates/goose-cli/src/session/export.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,20 @@ pub fn message_to_markdown(message: &Message, export_all_content: bool) -> Strin
349349
tool_name
350350
));
351351
}
352+
ActionRequiredData::Elicitation { message, .. } => {
353+
md.push_str(&format!(
354+
"**Action Required** (elicitation): {}\n\n",
355+
message
356+
));
357+
}
358+
ActionRequiredData::ElicitationResponse { id, user_data } => {
359+
md.push_str(&format!(
360+
"**Action Required** (elicitation_response): {}\n```json\n{}\n```\n\n",
361+
id,
362+
serde_json::to_string_pretty(user_data)
363+
.unwrap_or_else(|_| "{}".to_string())
364+
));
365+
}
352366
},
353367
MessageContent::Text(text) => {
354368
md.push_str(&text.text);

crates/goose-cli/src/session/mod.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod builder;
22
mod completion;
3+
mod elicitation;
34
mod export;
45
mod input;
56
mod output;
@@ -865,6 +866,18 @@ impl CliSession {
865866
}
866867
});
867868

869+
let elicitation_request = message.content.iter().find_map(|content| {
870+
if let MessageContent::ActionRequired(action) = content {
871+
if let ActionRequiredData::Elicitation { id, message, requested_schema } = &action.data {
872+
Some((id.clone(), message.clone(), requested_schema.clone()))
873+
} else {
874+
None
875+
}
876+
} else {
877+
None
878+
}
879+
});
880+
868881
if let Some((id, _tool_name, _arguments, security_prompt)) = tool_call_confirmation {
869882
output::hide_thinking();
870883

@@ -924,6 +937,48 @@ impl CliSession {
924937
}).await;
925938
}
926939
}
940+
else if let Some((elicitation_id, elicitation_message, schema)) = elicitation_request {
941+
output::hide_thinking();
942+
let _ = progress_bars.hide();
943+
944+
match elicitation::collect_elicitation_input(&elicitation_message, &schema) {
945+
Ok(Some(user_data)) => {
946+
let user_data_value = serde_json::to_value(user_data)
947+
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
948+
949+
let response_message = Message::user()
950+
.with_content(MessageContent::action_required_elicitation_response(
951+
elicitation_id.clone(),
952+
user_data_value,
953+
))
954+
.with_visibility(false, true);
955+
956+
self.messages.push(response_message.clone());
957+
// Elicitation responses return an empty stream - the response
958+
// unblocks the waiting tool call via ActionRequiredManager
959+
let _ = self
960+
.agent
961+
.reply(
962+
response_message,
963+
session_config.clone(),
964+
Some(cancel_token.clone()),
965+
)
966+
.await?;
967+
}
968+
Ok(None) => {
969+
output::render_text("Information request cancelled.", Some(Color::Yellow), true);
970+
cancel_token_clone.cancel();
971+
drop(stream);
972+
break;
973+
}
974+
Err(e) => {
975+
output::render_error(&format!("Failed to collect input: {}", e));
976+
cancel_token_clone.cancel();
977+
drop(stream);
978+
break;
979+
}
980+
}
981+
}
927982
else {
928983
for content in &message.content {
929984
if let MessageContent::ToolRequest(tool_request) = content {

crates/goose-cli/src/session/output.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ pub fn render_message(message: &Message, debug: bool) {
172172
ActionRequiredData::ToolConfirmation { tool_name, .. } => {
173173
println!("action_required(tool_confirmation): {}", tool_name)
174174
}
175+
ActionRequiredData::Elicitation { message, .. } => {
176+
println!("action_required(elicitation): {}", message)
177+
}
178+
ActionRequiredData::ElicitationResponse { id, .. } => {
179+
println!("action_required(elicitation_response): {}", id)
180+
}
175181
},
176182
MessageContent::Text(text) => print_markdown(&text.text, theme),
177183
MessageContent::ToolRequest(req) => render_tool_request(req, theme, debug),
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
use anyhow::Result;
2+
use serde_json::Value;
3+
use std::collections::HashMap;
4+
use std::sync::Arc;
5+
use std::time::Duration;
6+
use tokio::sync::{mpsc, Mutex, RwLock};
7+
use tokio::time::timeout;
8+
use tracing::warn;
9+
use uuid::Uuid;
10+
11+
use crate::conversation::message::{Message, MessageContent};
12+
13+
struct PendingRequest {
14+
response_tx: Option<tokio::sync::oneshot::Sender<Value>>,
15+
}
16+
17+
pub struct ActionRequiredManager {
18+
pending: Arc<RwLock<HashMap<String, Arc<Mutex<PendingRequest>>>>>,
19+
request_tx: mpsc::UnboundedSender<Message>,
20+
pub request_rx: Mutex<mpsc::UnboundedReceiver<Message>>,
21+
}
22+
23+
impl ActionRequiredManager {
24+
fn new() -> Self {
25+
let (request_tx, request_rx) = mpsc::unbounded_channel();
26+
Self {
27+
pending: Arc::new(RwLock::new(HashMap::new())),
28+
request_tx,
29+
request_rx: Mutex::new(request_rx),
30+
}
31+
}
32+
33+
pub fn global() -> &'static Self {
34+
static INSTANCE: once_cell::sync::Lazy<ActionRequiredManager> =
35+
once_cell::sync::Lazy::new(ActionRequiredManager::new);
36+
&INSTANCE
37+
}
38+
39+
pub async fn request_and_wait(
40+
&self,
41+
message: String,
42+
schema: Value,
43+
timeout_duration: Duration,
44+
) -> Result<Value> {
45+
let id = Uuid::new_v4().to_string();
46+
let (tx, rx) = tokio::sync::oneshot::channel();
47+
let pending_request = PendingRequest {
48+
response_tx: Some(tx),
49+
};
50+
51+
self.pending
52+
.write()
53+
.await
54+
.insert(id.clone(), Arc::new(Mutex::new(pending_request)));
55+
56+
let action_required_message = Message::assistant().with_content(
57+
MessageContent::action_required_elicitation(id.clone(), message, schema),
58+
);
59+
60+
if let Err(e) = self.request_tx.send(action_required_message) {
61+
warn!("Failed to send action required message: {}", e);
62+
}
63+
64+
let result = match timeout(timeout_duration, rx).await {
65+
Ok(Ok(user_data)) => Ok(user_data),
66+
Ok(Err(_)) => {
67+
warn!("Response channel closed for request: {}", id);
68+
Err(anyhow::anyhow!("Response channel closed"))
69+
}
70+
Err(_) => {
71+
warn!("Timeout waiting for response: {}", id);
72+
Err(anyhow::anyhow!("Timeout waiting for user response"))
73+
}
74+
};
75+
76+
self.pending.write().await.remove(&id);
77+
78+
result
79+
}
80+
81+
pub async fn submit_response(&self, request_id: String, user_data: Value) -> Result<()> {
82+
let pending_arc = {
83+
let pending = self.pending.read().await;
84+
pending
85+
.get(&request_id)
86+
.cloned()
87+
.ok_or_else(|| anyhow::anyhow!("Request not found: {}", request_id))?
88+
};
89+
90+
let mut pending = pending_arc.lock().await;
91+
if let Some(tx) = pending.response_tx.take() {
92+
if tx.send(user_data).is_err() {
93+
warn!("Failed to send response through oneshot channel");
94+
}
95+
}
96+
97+
Ok(())
98+
}
99+
}

0 commit comments

Comments
 (0)