Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/google/adk/evaluation/eval_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Optional
from typing import Union

from google.adk.models.base_llm import BaseLlm
from google.genai import types as genai_types
from pydantic import alias_generators
from pydantic import BaseModel
Expand Down Expand Up @@ -75,10 +76,11 @@ class PrebuiltMetrics(Enum):
class JudgeModelOptions(EvalBaseModel):
  """Options for an eval metric's judge model."""

  # Accepts either a model-name string (resolved later via LLMRegistry) or a
  # ready-made BaseLlm instance for custom/self-hosted judge models.
  judge_model: Union[str, BaseLlm] = Field(
      default="gemini-2.5-flash",
      description=(
          "The judge model to use for evaluation. It can be a model name"
          " string or a BaseLlm instance for custom/self-hosted models."
      ),
  )

Expand Down
13 changes: 9 additions & 4 deletions src/google/adk/evaluation/hallucinations_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,22 @@ def __init__(self, eval_metric: EvalMetric):
self._judge_model = self._setup_auto_rater()
self.segmenter_prompt = _HALLUCINATIONS_V1_SEGMENTER_PROMPT
self.sentence_validator_prompt = _HALLUCINATIONS_V1_VALIDATOR_PROMPT
self._model = self._judge_model_options.judge_model
judge_model = self._judge_model_options.judge_model
self._model = (
judge_model if isinstance(judge_model, str) else judge_model.model
)
self._model_config = (
self._judge_model_options.judge_model_config
or genai_types.GenerateContentConfig()
)

def _setup_auto_rater(self) -> BaseLlm:
  """Builds the judge-model LLM for this evaluator.

  Returns:
    The configured ``judge_model`` unchanged when it is already a
    ``BaseLlm`` instance (custom/self-hosted models); otherwise treats it
    as a model-name string and instantiates the concrete LLM class
    resolved through ``LLMRegistry``.
  """
  judge_model = self._judge_model_options.judge_model
  if isinstance(judge_model, BaseLlm):
    # Caller supplied a ready-made LLM instance; use it as-is.
    return judge_model
  llm_registry = LLMRegistry()
  llm_class = llm_registry.resolve(judge_model)
  return llm_class(model=judge_model)

def _create_context_for_step(
self,
Expand Down
14 changes: 10 additions & 4 deletions src/google/adk/evaluation/llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ async def evaluate_invocations(
per_invocation_results = []
for actual, expected in zip(actual_invocations, expected_invocations):
auto_rater_prompt = self.format_auto_rater_prompt(actual, expected)
judge_model = self._judge_model_options.judge_model
model_str = (
judge_model if isinstance(judge_model, str) else judge_model.model
)
llm_request = LlmRequest(
model=self._judge_model_options.judge_model,
model=model_str,
contents=[
genai_types.Content(
parts=[genai_types.Part(text=auto_rater_prompt)],
Expand Down Expand Up @@ -181,7 +185,9 @@ async def evaluate_invocations(
return EvaluationResult()

def _setup_auto_rater(self) -> BaseLlm:
  """Builds the judge-model LLM for this evaluator.

  Returns:
    The configured ``judge_model`` unchanged when it is already a
    ``BaseLlm`` instance (custom/self-hosted models); otherwise treats it
    as a model-name string and instantiates the concrete LLM class
    resolved through ``LLMRegistry``.
  """
  judge_model = self._judge_model_options.judge_model
  if isinstance(judge_model, BaseLlm):
    # Caller supplied a ready-made LLM instance; use it as-is.
    return judge_model
  llm_registry = LLMRegistry()
  llm_class = llm_registry.resolve(judge_model)
  return llm_class(model=judge_model)
24 changes: 18 additions & 6 deletions src/google/adk/evaluation/simulation/llm_backed_user_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import logging
from typing import ClassVar
from typing import Optional
from typing import Union

from google.genai import types as genai_types
from pydantic import Field
from pydantic import field_validator
from typing_extensions import override

from ...events.event import Event
from ...models.base_llm import BaseLlm
from ...models.llm_request import LlmRequest
from ...models.registry import LLMRegistry
from ...utils.context_utils import Aclosing
Expand All @@ -47,9 +49,12 @@
class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
  """Contains configurations required by an LLM backed user simulator."""

  # Accepts either a model-name string (resolved later via LLMRegistry) or a
  # ready-made BaseLlm instance for custom/self-hosted models.
  model: Union[str, BaseLlm] = Field(
      default="gemini-2.5-flash",
      description=(
          "The model to use for user simulation. It can be a model name"
          " string or a BaseLlm instance for custom/self-hosted models."
      ),
  )

model_configuration: genai_types.GenerateContentConfig = Field(
Expand Down Expand Up @@ -124,9 +129,12 @@ def __init__(
super().__init__(config, config_type=LlmBackedUserSimulator.config_type)
self._conversation_scenario = conversation_scenario
self._invocation_count = 0
llm_registry = LLMRegistry()
llm_class = llm_registry.resolve(self._config.model)
self._llm = llm_class(model=self._config.model)
if isinstance(self._config.model, BaseLlm):
self._llm = self._config.model
else:
llm_registry = LLMRegistry()
llm_class = llm_registry.resolve(self._config.model)
self._llm = llm_class(model=self._config.model)
self._user_persona = self._conversation_scenario.user_persona

@classmethod
Expand Down Expand Up @@ -172,8 +180,12 @@ async def _get_llm_response(
user_persona=self._user_persona,
)

config_model = self._config.model
model_str = (
config_model if isinstance(config_model, str) else config_model.model
)
llm_request = LlmRequest(
model=self._config.model,
model=model_str,
config=self._config.model_configuration,
contents=[
genai_types.Content(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,12 @@ async def evaluate_invocations(
return self._aggregate_conversation_results(results)

def _setup_llm(self) -> BaseLlm:
  """Builds the judge-model LLM for this evaluator.

  Returns:
    The configured ``judge_model`` unchanged when it is already a
    ``BaseLlm`` instance (custom/self-hosted models); otherwise treats it
    as a model-name string and instantiates the concrete LLM class
    resolved through ``LLMRegistry``.
  """
  judge_model = self._llm_options.judge_model
  if isinstance(judge_model, BaseLlm):
    # Caller supplied a ready-made LLM instance; use it as-is.
    return judge_model
  llm_registry = LLMRegistry()
  llm_class = llm_registry.resolve(judge_model)
  return llm_class(model=judge_model)

def _format_llm_prompt(
self,
Expand Down Expand Up @@ -325,8 +327,12 @@ async def _evaluate_intermediate_turn(
previous_invocations=invocation_history,
)

judge_model = self._llm_options.judge_model
model_str = (
judge_model if isinstance(judge_model, str) else judge_model.model
)
llm_request = LlmRequest(
model=self._llm_options.judge_model,
model=model_str,
contents=[
genai_types.Content(
parts=[genai_types.Part(text=auto_rater_prompt)],
Expand Down