Skip to content
10 changes: 7 additions & 3 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..memory.base_memory_service import BaseMemoryService
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
from ..runners import Runner
from .adk_web_server import AdkWebServer
from .service_registry import load_services_module
Expand Down Expand Up @@ -79,6 +82,7 @@ def get_fast_api_app(
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
use_local_storage: bool = True,
memory_service: Optional[BaseMemoryService] = None,
eval_storage_uri: Optional[str] = None,
allow_origins: Optional[list[str]] = None,
web: bool,
Expand Down Expand Up @@ -161,13 +165,13 @@ def get_fast_api_app(
load_services_module(agents_dir)

# Build the Memory service
try:
if memory_service:
pass
else:
memory_service = create_memory_service_from_options(
base_dir=agents_dir,
memory_service_uri=memory_service_uri,
)
except ValueError as exc:
raise click.ClickException(str(exc)) from exc

# Build the Session service
session_service = create_session_service_from_options(
Expand Down
70 changes: 70 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,5 +1768,75 @@ async def run_async_session_not_found(self, **kwargs):
assert "Session not found" in response.json()["detail"]


def test_get_fast_api_app_with_custom_memory_service(
mock_session_service,
mock_artifact_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
"""Test that custom memory_service is used directly when provided."""
custom_memory_service = MagicMock()

with (
patch.object(signal, "signal", autospec=True, return_value=None),
patch.object(
fast_api_module,
"create_session_service_from_options",
autospec=True,
return_value=mock_session_service,
),
patch.object(
fast_api_module,
"create_artifact_service_from_options",
autospec=True,
return_value=mock_artifact_service,
),
patch.object(
fast_api_module,
"create_memory_service_from_options",
autospec=True,
) as mock_create_memory_service,
patch.object(
fast_api_module,
"AgentLoader",
autospec=True,
return_value=mock_agent_loader,
),
patch.object(
fast_api_module,
"LocalEvalSetsManager",
autospec=True,
return_value=mock_eval_sets_manager,
),
patch.object(
fast_api_module,
"LocalEvalSetResultsManager",
autospec=True,
return_value=mock_eval_set_results_manager,
),
patch.object(
fast_api_module,
"load_services_module",
autospec=True,
return_value=None,
),
):
app = get_fast_api_app(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
memory_service=custom_memory_service,
allow_origins=["*"],
a2a=False,
host="127.0.0.1",
port=8000,
)

mock_create_memory_service.assert_not_called()


if __name__ == "__main__":
pytest.main(["-xvs", __file__])