Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,11 @@ private List<Part> toParts(AiMessage aiMessage) {
});
return parts;
} else {
Part part = Part.builder().text(aiMessage.text()).build();
return List.of(part);
String text = aiMessage.text();
if (text == null) {
return List.of();
}
return List.of(Part.builder().text(text).build());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,263 +711,26 @@ void testGenerateContentWithStructuredResponseJsonSchema() {
}

@Test
@DisplayName("Should handle MCP tools with parametersJsonSchema")
void testGenerateContentWithMcpToolParametersJsonSchema() {
@DisplayName("Should handle null AiMessage text without throwing NPE")
void testGenerateContentWithNullAiMessageText() {
// Given
// Create a mock BaseTool for MCP tool
final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class);
when(mcpTool.name()).thenReturn("mcpTool");
when(mcpTool.description()).thenReturn("An MCP tool");

// Create a mock FunctionDeclaration
final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class);
when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration));

// MCP tools use parametersJsonSchema() instead of parameters()
// Create a JSON schema object (Map representation)
final Map<String, Object> jsonSchemaMap =
Map.of(
"type",
"object",
"properties",
Map.of("city", Map.of("type", "string", "description", "City name")),
"required",
List.of("city"));

// Mock parametersJsonSchema() to return the JSON schema object
when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(jsonSchemaMap));
when(functionDeclaration.parameters()).thenReturn(Optional.empty());

// Create a LlmRequest with the MCP tool
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool"))))
.tools(Map.of("mcpTool", mcpTool))
.build();

// Mock the AI response
final AiMessage aiMessage = AiMessage.from("Tool executed successfully");

final ChatResponse chatResponse = mock(ChatResponse.class);
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Tool executed successfully");

// Verify the request was built correctly with the tool specification
final ArgumentCaptor<ChatRequest> requestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
verify(chatModel).chat(requestCaptor.capture());
final ChatRequest capturedRequest = requestCaptor.getValue();

// Verify tool specifications were created from parametersJsonSchema
assertThat(capturedRequest.toolSpecifications()).isNotEmpty();
assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool");
assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool");
}

/**
 * Verifies that an MCP tool whose {@code FunctionDeclaration.parametersJsonSchema()}
 * already yields a built {@code Schema} object (rather than a {@code Map}-based JSON
 * schema) is still translated into a tool specification on the outgoing ChatRequest.
 *
 * <p>NOTE(review): relies on {@code chatModel} and {@code langChain4j} fixtures
 * declared elsewhere in this test class — not visible in this chunk; presumably
 * initialized in a {@code @BeforeEach} setup method.
 */
@Test
@DisplayName("Should handle MCP tools with parametersJsonSchema when it's already a Schema")
void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() {
// Given
// Create a mock BaseTool for MCP tool
final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class);
when(mcpTool.name()).thenReturn("mcpTool");
when(mcpTool.description()).thenReturn("An MCP tool");

// Create a mock FunctionDeclaration
final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class);
when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration));

// Create a Schema object directly (when parametersJsonSchema returns Schema)
final Schema cityPropertySchema =
Schema.builder().type("STRING").description("City name").build();

final Schema objectSchema =
Schema.builder()
.type("OBJECT")
.properties(Map.of("city", cityPropertySchema))
.required(List.of("city"))
.build();

// Mock parametersJsonSchema() to return Schema directly
when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(objectSchema));
// parameters() is empty so the adapter must fall back to parametersJsonSchema()
when(functionDeclaration.parameters()).thenReturn(Optional.empty());

// Create a LlmRequest with the MCP tool
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool"))))
.tools(Map.of("mcpTool", mcpTool))
.build();

// Mock the AI response
final AiMessage aiMessage = AiMessage.from("Tool executed successfully");

final ChatResponse chatResponse = mock(ChatResponse.class);
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Tool executed successfully");

// Verify the request was built correctly with the tool specification
final ArgumentCaptor<ChatRequest> requestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
verify(chatModel).chat(requestCaptor.capture());
final ChatRequest capturedRequest = requestCaptor.getValue();

// Verify tool specifications were created from parametersJsonSchema
assertThat(capturedRequest.toolSpecifications()).isNotEmpty();
assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool");
assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool");
}

/**
 * Verifies that when the LLM response carries no {@code TokenUsage}, a configured
 * {@code TokenCountEstimator} is used to populate the response usage metadata:
 * prompt tokens from {@code estimateTokenCountInMessages}, candidate tokens from
 * {@code estimateTokenCountInText}, and total as their sum (50 + 20 = 70 here).
 *
 * <p>NOTE(review): relies on the {@code chatModel} mock and {@code MODEL_NAME}
 * constant declared elsewhere in this test class — not visible in this chunk.
 */
@Test
@DisplayName(
"Should use TokenCountEstimator to estimate token usage when TokenUsage is not available")
void testTokenCountEstimatorFallback() {
// Given
// Create a mock TokenCountEstimator
final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens
when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens

// Create LangChain4j with the TokenCountEstimator using Builder
final LangChain4j langChain4jWithEstimator =
LangChain4j.builder()
.chatModel(chatModel)
.modelName(MODEL_NAME)
.tokenCountEstimator(tokenCountEstimator)
.build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
.build();

// Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts)
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response has usage metadata estimated by TokenCountEstimator
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("The weather is sunny today.");

// IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator
assertThat(response.usageMetadata()).isPresent();
final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator
assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20

// Verify the estimator was actually called
verify(tokenCountEstimator).estimateTokenCountInMessages(any());
verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
}

/**
 * Verifies that a configured {@code TokenCountEstimator} takes priority over the
 * {@code TokenUsage} reported by the LLM itself: even though the mocked response
 * carries (30, 15, 45), the usage metadata reflects the estimator's values
 * (100, 50, 150) and the estimator's methods are verifiably invoked.
 *
 * <p>NOTE(review): relies on the {@code chatModel} mock and {@code MODEL_NAME}
 * constant declared elsewhere in this test class — not visible in this chunk.
 */
@Test
@DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided")
void testTokenCountEstimatorPriority() {
// Given
// Create a mock TokenCountEstimator
final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator
when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator

// Create LangChain4j with the TokenCountEstimator using Builder
final LangChain4j langChain4jWithEstimator =
LangChain4j.builder()
.chatModel(chatModel)
.modelName(MODEL_NAME)
.tokenCountEstimator(tokenCountEstimator)
.build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
.build();

// Mock ChatResponse WITH actual TokenUsage from the LLM
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage
assertThat(response).isNotNull();
assertThat(response.usageMetadata()).isPresent();
final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator
assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50

// Verify the estimator was called (it takes priority)
verify(tokenCountEstimator).estimateTokenCountInMessages(any());
verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
}

@Test
@DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided")
void testNoUsageMetadataWithoutEstimator() {
// Given
// Create LangChain4j WITHOUT TokenCountEstimator (default behavior)
final LangChain4j langChain4jNoEstimator =
LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Hello, world!"))))
.build();
LlmRequest.builder().contents(List.of(Content.fromParts(Part.fromText("Hello")))).build();

// Mock ChatResponse WITHOUT TokenUsage
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?");
final AiMessage aiMessage = mock(AiMessage.class);
when(aiMessage.text()).thenReturn(null);
when(aiMessage.hasToolExecutionRequests()).thenReturn(false);
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst();
final Flowable<LlmResponse> responseFlowable = langChain4j.generateContent(llmRequest, false);
final LlmResponse response = responseFlowable.blockingFirst();

// Then
// Verify the response does NOT have usage metadata
// Then - no NPE thrown, and content has no text parts
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?");

// IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator
assertThat(response.usageMetadata()).isEmpty();
assertThat(response.content().get().parts().orElse(List.of())).isEmpty();
}
}
Loading