From aa46694ddffe5d185525fce6fbe94bdaf52a6fef Mon Sep 17 00:00:00 2001 From: Adwait Kumar Singh Date: Thu, 15 Jan 2026 00:29:16 +0530 Subject: [PATCH 1/2] Revert "Add MCP proxy server support for prompts" This reverts commit ebd73102f8ef36a13ae4040763b46cf8ce0db4bb. --- .../smithy/java/mcp/server/McpService.java | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index ca7c10ccf..37d7e8f76 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -220,19 +220,8 @@ private JsonRpcResponse handleInitialize(JsonRpcRequest req) { } private JsonRpcResponse handlePromptsList(JsonRpcRequest req) { - var allPrompts = new ArrayList<>(prompts.values().stream().map(Prompt::promptInfo).toList()); - - // Add prompts from proxy servers - for (McpServerProxy proxy : proxies.values()) { - var response = proxy.rpc(req).join(); - if (response.getError() == null) { - var proxyPrompts = response.getResult().asShape(ListPromptsResult.builder()).getPrompts(); - allPrompts.addAll(proxyPrompts); - } - } - var result = ListPromptsResult.builder() - .prompts(allPrompts) + .prompts(prompts.values().stream().map(Prompt::promptInfo).toList()) .build(); return createSuccessResponse(req.getId(), result); } @@ -243,12 +232,12 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) { var prompt = prompts.get(normalize(promptName)); - if (prompt != null) { - var result = promptProcessor.buildPromptResult(prompt, promptArguments); - return createSuccessResponse(req.getId(), result); + if (prompt == null) { + throw new RuntimeException("Prompt not found: " + promptName); } - throw new RuntimeException("Prompt not found: " + promptName); + var result = promptProcessor.buildPromptResult(prompt, promptArguments); + return createSuccessResponse(req.getId(), result); } private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion protocolVersion) { From 0bf18286d74432286a68a05dbd7c1a82e16beec8 Mon Sep 17 00:00:00 2001 From: Adwait Kumar Singh Date: Thu, 15 Jan 2026 02:08:36 +0530 Subject: [PATCH 2/2] Support Prompts for MCP Proxies --- .../java/mcp/server/McpServerProxy.java | 21 ++ .../smithy/java/mcp/server/McpService.java | 33 ++- .../amazon/smithy/java/mcp/server/Prompt.java | 222 +++++++++++++++++- .../java/mcp/server/PromptProcessor.java | 164 ------------- .../java/mcp/server/PromptProcessorTest.java | 140 ----------- .../smithy/java/mcp/server/PromptTest.java | 175 ++++++++++++++ 6 files changed, 446 insertions(+), 309 deletions(-) delete mode 100644 mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/PromptProcessor.java delete mode 100644 mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptProcessorTest.java create mode 100644 mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptTest.java diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java index 623a96317..5997c6d93 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java @@ -5,6 +5,8 @@ package software.amazon.smithy.java.mcp.server; +import static software.amazon.smithy.java.mcp.model.ListPromptsResult.builder; + import java.util.List; import java.util.Objects; import java.util.concurrent.CompletableFuture; @@ -17,6 +19,7 @@ import software.amazon.smithy.java.mcp.model.JsonRpcRequest; import software.amazon.smithy.java.mcp.model.JsonRpcResponse; import software.amazon.smithy.java.mcp.model.ListToolsResult; +import software.amazon.smithy.java.mcp.model.PromptInfo; import software.amazon.smithy.java.mcp.model.ToolInfo; public abstract class McpServerProxy { @@ -46,6 +49,24 @@ public List listTools() { }).join(); } + public List listPrompts() { + JsonRpcRequest request = JsonRpcRequest.builder() + .method("prompts/list") + .id(generateRequestId()) + .jsonrpc("2.0") + .build(); + return rpc(request).thenApply(response -> { + if (response.getError() != null) { + throw new RuntimeException("Error listing prompts: " + response.getError().getMessage()); + } + return response.getResult() + .asShape(builder()) + .getPrompts() + .stream() + .toList(); + }).join(); + } + public void initialize( Consumer notificationConsumer, JsonRpcRequest initializeRequest, diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 37d7e8f76..c0440710a 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -52,6 +52,7 @@ import software.amazon.smithy.java.mcp.model.JsonRpcResponse; import software.amazon.smithy.java.mcp.model.ListPromptsResult; import software.amazon.smithy.java.mcp.model.ListToolsResult; +import software.amazon.smithy.java.mcp.model.PromptInfo; import software.amazon.smithy.java.mcp.model.Prompts; import software.amazon.smithy.java.mcp.model.ServerInfo; import software.amazon.smithy.java.mcp.model.TextContent; @@ -84,7 +85,6 @@ public final class McpService { private final Map tools; private final Map prompts; - private final PromptProcessor promptProcessor; private final String serviceName; private final String version; private final Map proxies; @@ -107,8 +107,7 @@ public final class McpService { this.schemaIndex = SchemaIndex.compose(services.values().stream().map(Service::schemaIndex).toArray(SchemaIndex[]::new)); this.tools = createTools(services); - this.prompts = PromptLoader.loadPrompts(services.values()); - this.promptProcessor = new PromptProcessor(); + this.prompts = new ConcurrentHashMap<>(PromptLoader.loadPrompts(services.values())); this.serviceName = name; this.version = version; this.proxies = proxyList.stream().collect(Collectors.toMap(McpServerProxy::name, p -> p)); @@ -236,7 +235,7 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) { throw new RuntimeException("Prompt not found: " + promptName); } - var result = promptProcessor.buildPromptResult(prompt, promptArguments); + var result = prompt.getPromptResult(promptArguments, req.getId()); return createSuccessResponse(req.getId(), result); } @@ -342,6 +341,19 @@ public void initializeProxies(Consumer responseWriter) { for (var toolInfo : proxyTools) { tools.put(toolInfo.getName(), new Tool(toolInfo, proxy.name(), proxy)); } + + // Fetch and register prompts from proxy + try { + List proxyPrompts = proxy.listPrompts(); + for (var promptInfo : proxyPrompts) { + var normalizedName = PromptLoader.normalize(promptInfo.getName()); + if (!prompts.containsKey(normalizedName)) { + prompts.put(normalizedName, new Prompt(promptInfo, proxy)); + } + } + } catch (Exception e) { + LOG.error("Failed to fetch prompts from proxy: " + proxy.name(), e); + } } } } @@ -376,6 +388,19 @@ public void addNewProxy(McpServerProxy mcpServerProxy, Consumer } catch (Exception e) { LOG.error("Failed to fetch tools from proxy", e); } + + // Also fetch prompts from the new proxy + try { + List proxyPrompts = mcpServerProxy.listPrompts(); + for (var promptInfo : proxyPrompts) { + var normalizedName = PromptLoader.normalize(promptInfo.getName()); + if (!prompts.containsKey(normalizedName)) { + prompts.put(normalizedName, new Prompt(promptInfo, mcpServerProxy)); + } + } + } catch (Exception e) { + LOG.error("Failed to fetch prompts from proxy: " + mcpServerProxy.name(), e); + } } /** diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/Prompt.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/Prompt.java index e905f1de1..51e54dded 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/Prompt.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/Prompt.java @@ -5,6 +5,226 @@ package software.amazon.smithy.java.mcp.server; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.java.mcp.model.GetPromptResult; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.java.mcp.model.PromptArgument; import software.amazon.smithy.java.mcp.model.PromptInfo; +import software.amazon.smithy.java.mcp.model.PromptMessage; +import software.amazon.smithy.java.mcp.model.PromptMessageContent; +import software.amazon.smithy.java.mcp.model.PromptMessageContentType; +import software.amazon.smithy.java.mcp.model.PromptRole; +import software.amazon.smithy.utils.SmithyUnstableApi; -public record Prompt(PromptInfo promptInfo, String promptTemplate) {} +/** + * Represents a prompt that can be either local (with a template) or proxied to a remote MCP server. + */ +@SmithyUnstableApi +public final class Prompt { + + private static final Pattern PROMPT_ARGUMENT_PLACEHOLDER = Pattern.compile("\\{\\{(\\w+)\\}\\}"); + + private final PromptInfo promptInfo; + private final String promptTemplate; + private final McpServerProxy proxy; + + /** + * Creates a local prompt with a template. + * + * @param promptInfo The prompt metadata + * @param promptTemplate The template string containing {{placeholder}} patterns + */ + public Prompt(PromptInfo promptInfo, String promptTemplate) { + this.promptInfo = promptInfo; + this.promptTemplate = promptTemplate; + this.proxy = null; + } + + /** + * Creates a proxy prompt that delegates to a remote MCP server. + * + * @param promptInfo The prompt metadata + * @param proxy The MCP server proxy to delegate to + */ + public Prompt(PromptInfo promptInfo, McpServerProxy proxy) { + this.promptInfo = promptInfo; + this.promptTemplate = null; + this.proxy = proxy; + } + + /** + * @return The prompt metadata + */ + public PromptInfo promptInfo() { + return promptInfo; + } + + /** + * Gets the prompt result, either by processing the local template or by + * forwarding the request to the proxy server. + * + * @param arguments Document containing argument values for template substitution + * @param requestId The request ID to use for proxy calls (may be null for local prompts) + * @return GetPromptResult with processed template or proxy response + */ + public GetPromptResult getPromptResult(Document arguments, Document requestId) { + if (proxy != null) { + return delegateToProxy(arguments, requestId); + } + return buildLocalPromptResult(arguments); + } + + /** + * Delegates the prompt request to the proxy server via RPC. + */ + private GetPromptResult delegateToProxy(Document arguments, Document requestId) { + Map params = new HashMap<>(); + params.put("name", Document.of(promptInfo.getName())); + if (arguments != null) { + params.put("arguments", arguments); + } + + JsonRpcRequest request = JsonRpcRequest.builder() + .method("prompts/get") + .id(requestId) + .params(Document.of(params)) + .jsonrpc("2.0") + .build(); + + return proxy.rpc(request).thenApply(response -> { + if (response.getError() != null) { + throw new RuntimeException("Error getting prompt: " + response.getError().getMessage()); + } + return response.getResult().asShape(GetPromptResult.builder()); + }).join(); + } + + /** + * Builds a GetPromptResult from the local template and provided arguments. + */ + private GetPromptResult buildLocalPromptResult(Document arguments) { + if (promptTemplate == null) { + return GetPromptResult.builder() + .description(promptInfo.getDescription()) + .messages(List.of( + PromptMessage.builder() + .role(PromptRole.ASSISTANT.getValue()) + .content(PromptMessageContent.builder() + .type(PromptMessageContentType.TEXT) + .text("Template is required for the prompt:" + promptInfo.getName()) + .build()) + .build())) + .build(); + } + + var requiredArguments = getRequiredArguments(); + + if (!requiredArguments.isEmpty() && arguments == null) { + return GetPromptResult.builder() + .description(promptInfo.getDescription()) + .messages(List.of(PromptMessage.builder() + .role(PromptRole.USER.getValue()) + .content(PromptMessageContent.builder() + .type(PromptMessageContentType.TEXT) + .text("Tell user that there are missing arguments for the prompt : " + + requiredArguments) + .build()) + .build())) + .build(); + } + + String processedText = applyTemplateArguments(promptTemplate, arguments); + + return GetPromptResult.builder() + .description(promptInfo.getDescription()) + .messages(List.of( + PromptMessage.builder() + .role(PromptRole.USER.getValue()) + .content(PromptMessageContent.builder() + .type(PromptMessageContentType.TEXT) + .text(processedText) + .build()) + .build())) + .build(); + } + + /** + * Applies template arguments to a template string. + * + * @param template The template string containing {{placeholder}} patterns + * @param arguments Document containing replacement values + * @return The template with all placeholders replaced + */ + private String applyTemplateArguments(String template, Document arguments) { + // Common cases + if (template == null || arguments == null || template.isEmpty()) { + return template; + } + + // Avoid any regex work if there are no potential placeholders + int firstBrace = template.indexOf("{{"); + if (firstBrace == -1) { + return template; + } + + Matcher matcher = PROMPT_ARGUMENT_PLACEHOLDER.matcher(template); + + int matchCount = 0; + int estimatedResultLength = template.length(); + Map replacementCache = new HashMap<>(); + + while (matcher.find()) { + matchCount++; + String argName = matcher.group(1); + + // Only look up each unique argument once + if (!replacementCache.containsKey(argName)) { + Document argValue = arguments.getMember(argName); + String replacement = (argValue != null) ? argValue.asString() : ""; + replacementCache.put(argName, replacement); + + // Adjust estimated length (subtract placeholder length, add replacement length) + estimatedResultLength = estimatedResultLength - matcher.group(0).length() + replacement.length(); + } + } + + // If no matches found, return original template + if (matchCount == 0) { + return template; + } + + // Reset matcher for the actual replacement pass + matcher.reset(); + + StringBuilder result = new StringBuilder(estimatedResultLength); + + // Single-pass replacement using cached values + while (matcher.find()) { + String argName = matcher.group(1); + String replacement = replacementCache.get(argName); + matcher.appendReplacement(result, Matcher.quoteReplacement(replacement)); + } + + matcher.appendTail(result); + + return result.toString(); + } + + /** + * Extracts the set of required argument names from the PromptInfo. + */ + private Set getRequiredArguments() { + return promptInfo.getArguments() + .stream() + .filter(PromptArgument::isRequired) + .map(PromptArgument::getName) + .collect(Collectors.toSet()); + } +} diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/PromptProcessor.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/PromptProcessor.java deleted file mode 100644 index a0c28baac..000000000 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/PromptProcessor.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.mcp.server; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import software.amazon.smithy.java.core.serde.document.Document; -import software.amazon.smithy.java.mcp.model.GetPromptResult; -import software.amazon.smithy.java.mcp.model.PromptArgument; -import software.amazon.smithy.java.mcp.model.PromptInfo; -import software.amazon.smithy.java.mcp.model.PromptMessage; -import software.amazon.smithy.java.mcp.model.PromptMessageContent; -import software.amazon.smithy.java.mcp.model.PromptMessageContentType; -import software.amazon.smithy.java.mcp.model.PromptRole; -import software.amazon.smithy.utils.SmithyUnstableApi; - -/** - * Handles processing of prompt templates and building prompt results. - */ -@SmithyUnstableApi -public final class PromptProcessor { - - private static final Pattern PROMPT_ARGUMENT_PLACEHOLDER = Pattern.compile("\\{\\{(\\w+)\\}\\}"); - - /** - * Builds a GetPromptResult from a PromptInfo and provided arguments. - * - * @param prompt The prompt information containing template and metadata - * @param arguments Document containing argument values for template substitution - * @return GetPromptResult with processed template or error messages - */ - public GetPromptResult buildPromptResult(Prompt prompt, Document arguments) { - String template = prompt.promptTemplate(); - if (template == null) { - return GetPromptResult.builder() - .description(prompt.promptInfo().getDescription()) - .messages(List.of( - PromptMessage.builder() - .role(PromptRole.ASSISTANT.getValue()) - .content(PromptMessageContent.builder() - .type(PromptMessageContentType.TEXT) - .text("Template is required for the prompt:" - + prompt.promptInfo().getName()) - .build()) - .build())) - .build(); - } - - var requiredArguments = getRequiredArguments(prompt.promptInfo()); - - if (!requiredArguments.isEmpty() && arguments == null) { - return GetPromptResult.builder() - .description(prompt.promptInfo().getDescription()) - .messages(List.of(PromptMessage.builder() - .role(PromptRole.USER.getValue()) - .content(PromptMessageContent.builder() - .type(PromptMessageContentType.TEXT) - .text("Tell user that there are missing arguments for the prompt : " - + requiredArguments) - .build()) - .build())) - .build(); - } - - String processedText = applyTemplateArguments(template, arguments); - - return GetPromptResult.builder() - .description(prompt.promptInfo().getDescription()) - .messages(List.of( - PromptMessage.builder() - .role(PromptRole.USER.getValue()) - .content(PromptMessageContent.builder() - .type(PromptMessageContentType.TEXT) - .text(processedText) - .build()) - .build())) - .build(); - } - - /** - * Applies template arguments to a template string. - * - * //TODO: Optimize it with indexes where the replacements need to be done. - * @param template The template string containing {{placeholder}} patterns - * @param arguments Document containing replacement values - * @return The template with all placeholders replaced - */ - public String applyTemplateArguments(String template, Document arguments) { - // Common cases - if (template == null || arguments == null || template.isEmpty()) { - return template; - } - - // Avoid any regex work if there are no potential placeholders - int firstBrace = template.indexOf("{{"); - if (firstBrace == -1) { - return template; - } - - Matcher matcher = PROMPT_ARGUMENT_PLACEHOLDER.matcher(template); - - int matchCount = 0; - int estimatedResultLength = template.length(); - Map replacementCache = new HashMap<>(); - - while (matcher.find()) { - matchCount++; - String argName = matcher.group(1); - - // Only look up each unique argument once - if (!replacementCache.containsKey(argName)) { - Document argValue = arguments.getMember(argName); - String replacement = (argValue != null) ? argValue.asString() : ""; - replacementCache.put(argName, replacement); - - // Adjust estimated length (subtract placeholder length, add replacement length) - estimatedResultLength = estimatedResultLength - matcher.group(0).length() + replacement.length(); - } - } - - // If no matches found, return original template - if (matchCount == 0) { - return template; - } - - // Reset matcher for the actual replacement pass - matcher.reset(); - - StringBuilder result = new StringBuilder(estimatedResultLength); - - // Single-pass replacement using cached values - while (matcher.find()) { - String argName = matcher.group(1); - String replacement = replacementCache.get(argName); - matcher.appendReplacement(result, Matcher.quoteReplacement(replacement)); - } - - matcher.appendTail(result); - - return result.toString(); - } - - /** - * Extracts the set of required argument names from a PromptInfo. - * - * @param promptInfo The prompt information to analyze - * @return Set of required argument names - */ - private Set getRequiredArguments(PromptInfo promptInfo) { - return promptInfo.getArguments() - .stream() - .filter(PromptArgument::isRequired) - .map(PromptArgument::getName) - .collect(Collectors.toSet()); - } -} diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptProcessorTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptProcessorTest.java deleted file mode 100644 index a4cf6a013..000000000 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptProcessorTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.mcp.server; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.Test; -import software.amazon.smithy.java.core.serde.document.Document; -import software.amazon.smithy.java.mcp.model.PromptArgument; -import software.amazon.smithy.java.mcp.model.PromptInfo; - -public class PromptProcessorTest { - - private final PromptProcessor processor = new PromptProcessor(); - - @Test - public void testApplyTemplateArgumentsWithSimpleSubstitution() { - String template = "Hello {{name}}!"; - Document arguments = Document.of(Map.of("name", Document.of("World"))); - - String result = processor.applyTemplateArguments(template, arguments); - - assertEquals("Hello World!", result); - } - - @Test - public void testApplyTemplateArgumentsWithMultipleSubstitutions() { - String template = "{{greeting}} {{name}}, welcome to {{place}}!"; - Document arguments = Document.of(Map.of( - "greeting", - Document.of("Hello"), - "name", - Document.of("X"), - "place", - Document.of("P"))); - - String result = processor.applyTemplateArguments(template, arguments); - - assertEquals("Hello X, welcome to P!", result); - } - - @Test - public void testApplyTemplateArgumentsWithMissingArgument() { - String template = "Hello {{name}}!"; - Document arguments = Document.of(Map.of("other", Document.of("value"))); - - String result = processor.applyTemplateArguments(template, arguments); - - assertEquals("Hello !", result); - } - - @Test - public void testApplyTemplateArgumentsWithNoPlaceholders() { - String template = "Hello World!"; - Document arguments = Document.of(Map.of("name", Document.of("John"))); - - String result = processor.applyTemplateArguments(template, arguments); - - assertEquals("Hello World!", result); - } - - @Test - public void testApplyTemplateArgumentsWithNullArguments() { - String template = "Hello {{name}}!"; - - String result = processor.applyTemplateArguments(template, null); - - assertEquals("Hello {{name}}!", result); - } - - @Test - public void testBuildPromptResultWithValidTemplate() { - PromptInfo promptInfo = PromptInfo.builder() - .name("test-prompt") - .description("A test prompt") - .arguments(List.of()) - .build(); - - Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); - - Document arguments = Document.of(Map.of("name", Document.of("World"))); - - var result = processor.buildPromptResult(prompt, arguments); - - assertNotNull(result); - assertEquals("A test prompt", result.getDescription()); - assertEquals(1, result.getMessages().size()); - assertEquals("Hello World!", result.getMessages().get(0).getContent().getText()); - } - - @Test - public void testBuildPromptResultWithNullTemplate() { - PromptInfo promptInfo = PromptInfo.builder() - .name("test-prompt") - .description("A test prompt") - .arguments(List.of()) - .build(); - - Prompt prompt = new Prompt(promptInfo, null); - - var result = processor.buildPromptResult(prompt, null); - - assertNotNull(result); - assertEquals("A test prompt", result.getDescription()); - assertEquals(1, result.getMessages().size()); - assertEquals("Template is required for the prompt:test-prompt", - result.getMessages().get(0).getContent().getText()); - } - - @Test - public void testBuildPromptResultWithMissingRequiredArguments() { - PromptArgument requiredArg = PromptArgument.builder() - .name("name") - .description("The name") - .required(true) - .build(); - - PromptInfo promptInfo = PromptInfo.builder() - .name("test-prompt") - .description("A test prompt") - .arguments(List.of(requiredArg)) - .build(); - - Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); - - var result = processor.buildPromptResult(prompt, null); - - assertNotNull(result); - assertEquals("A test prompt", result.getDescription()); - assertEquals(1, result.getMessages().size()); - assertEquals("Tell user that there are missing arguments for the prompt : [name]", - result.getMessages().get(0).getContent().getText()); - } -} diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptTest.java new file mode 100644 index 000000000..372ccec18 --- /dev/null +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/PromptTest.java @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.java.mcp.model.PromptArgument; +import software.amazon.smithy.java.mcp.model.PromptInfo; + +public class PromptTest { + + @Test + public void testGetPromptResultWithSimpleSubstitution() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); + Document arguments = Document.of(Map.of("name", Document.of("World"))); + + var result = prompt.getPromptResult(arguments, null); + + assertNotNull(result); + assertEquals("Hello World!", result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithMultipleSubstitutions() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, "{{greeting}} {{name}}, welcome to {{place}}!"); + Document arguments = Document.of(Map.of( + "greeting", + Document.of("Hello"), + "name", + Document.of("X"), + "place", + Document.of("P"))); + + var result = prompt.getPromptResult(arguments, null); + + assertNotNull(result); + assertEquals("Hello X, welcome to P!", result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithMissingArgument() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); + Document arguments = Document.of(Map.of("other", Document.of("value"))); + + var result = prompt.getPromptResult(arguments, null); + + assertNotNull(result); + assertEquals("Hello !", result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithNoPlaceholders() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, "Hello World!"); + Document arguments = Document.of(Map.of("name", Document.of("John"))); + + var result = prompt.getPromptResult(arguments, null); + + assertNotNull(result); + assertEquals("Hello World!", result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithNullArguments() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); + + var result = prompt.getPromptResult(null, null); + + assertNotNull(result); + // When arguments is null and there are no required arguments, + // template placeholders remain unchanged + assertEquals("Hello {{name}}!", result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithValidTemplate() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); + + Document arguments = Document.of(Map.of("name", Document.of("World"))); + + var result = prompt.getPromptResult(arguments, null); + + assertNotNull(result); + assertEquals("A test prompt", result.getDescription()); + assertEquals(1, result.getMessages().size()); + assertEquals("Hello World!", result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithNullTemplate() { + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of()) + .build(); + + Prompt prompt = new Prompt(promptInfo, (String) null); + + var result = prompt.getPromptResult(null, null); + + assertNotNull(result); + assertEquals("A test prompt", result.getDescription()); + assertEquals(1, result.getMessages().size()); + assertEquals("Template is required for the prompt:test-prompt", + result.getMessages().get(0).getContent().getText()); + } + + @Test + public void testGetPromptResultWithMissingRequiredArguments() { + PromptArgument requiredArg = PromptArgument.builder() + .name("name") + .description("The name") + .required(true) + .build(); + + PromptInfo promptInfo = PromptInfo.builder() + .name("test-prompt") + .description("A test prompt") + .arguments(List.of(requiredArg)) + .build(); + + Prompt prompt = new Prompt(promptInfo, "Hello {{name}}!"); + + var result = prompt.getPromptResult(null, null); + + assertNotNull(result); + assertEquals("A test prompt", result.getDescription()); + assertEquals(1, result.getMessages().size()); + assertEquals("Tell user that there are missing arguments for the prompt : [name]", + result.getMessages().get(0).getContent().getText()); + } +}