diff --git a/pyproject.toml b/pyproject.toml index 11afcd82..1ee70936 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,9 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +firestore = [ + "google-cloud-firestore>=2.16.0", +] test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", diff --git a/src/google/adk_community/sessions/__init__.py b/src/google/adk_community/sessions/__init__.py index 90bf28d7..a4540bbc 100644 --- a/src/google/adk_community/sessions/__init__.py +++ b/src/google/adk_community/sessions/__init__.py @@ -14,6 +14,7 @@ """Community session services for ADK.""" +from .firestore_session_service import FirestoreSessionService from .redis_session_service import RedisSessionService -__all__ = ["RedisSessionService"] +__all__ = ["FirestoreSessionService", "RedisSessionService"] diff --git a/src/google/adk_community/sessions/firestore_session_service.py b/src/google/adk_community/sessions/firestore_session_service.py new file mode 100644 index 00000000..d86aaec6 --- /dev/null +++ b/src/google/adk_community/sessions/firestore_session_service.py @@ -0,0 +1,441 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firestore-backed session service for Google ADK. + +Provides persistent, serverless session storage using Google Cloud +Firestore. Well-suited for Cloud Run, Cloud Functions, or any GCP +environment where managing a SQL database is undesirable. + +Firestore collection layout:: + + {prefix}adk_app_states/{app_name} + {prefix}adk_user_states/{app_name}_{user_id} + {prefix}adk_sessions/{session_id} + -> subcollection: events/{event_id} + +Requires the ``google-cloud-firestore`` package:: + + pip install google-cloud-firestore +""" + +from __future__ import annotations + +import copy +import logging +import time +from typing import Any, Optional +import uuid + +from typing_extensions import override + +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import ( + BaseSessionService, + GetSessionConfig, + ListSessionsResponse, +) +from google.adk.sessions.session import Session +from google.adk.sessions.state import State + +logger = logging.getLogger("google_adk." + __name__) + +_APP_STATES_COLLECTION = "adk_app_states" +_USER_STATES_COLLECTION = "adk_user_states" +_SESSIONS_COLLECTION = "adk_sessions" +_EVENTS_SUBCOLLECTION = "events" + +_FIELD_APP_NAME = "app_name" +_FIELD_USER_ID = "user_id" +_FIELD_STATE = "state" +_FIELD_CREATE_TIME = "create_time" +_FIELD_UPDATE_TIME = "update_time" +_FIELD_EVENT_DATA = "event_data" +_FIELD_TIMESTAMP = "timestamp" +_FIELD_INVOCATION_ID = "invocation_id" + +_BATCH_DELETE_LIMIT = 500 + + +def _user_state_doc_id(app_name: str, user_id: str) -> str: + return f"{app_name}_{user_id}" + + +def _extract_state_delta( + state: Optional[dict[str, Any]], +) -> dict[str, dict[str, Any]]: + """Splits a state dict into app / user / session buckets.""" + deltas: dict[str, dict[str, Any]] = {"app": {}, "user": {}, "session": {}} + if not state: + return deltas + for key, value in state.items(): + if key.startswith(State.APP_PREFIX): + deltas["app"][key.removeprefix(State.APP_PREFIX)] = value + elif key.startswith(State.USER_PREFIX): + deltas["user"][key.removeprefix(State.USER_PREFIX)] = value + elif not key.startswith(State.TEMP_PREFIX): + deltas["session"][key] = value + return deltas + + +def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], +) -> dict[str, Any]: + """Combines app / user / session state into the flat dict ADK expects.""" + merged = copy.deepcopy(session_state) + for key, value in app_state.items(): + merged[State.APP_PREFIX + key] = value + for key, value in user_state.items(): + merged[State.USER_PREFIX + key] = value + return merged + + +class FirestoreSessionService(BaseSessionService): + """A session service backed by Google Cloud Firestore. + + Args: + project: GCP project ID. ``None`` uses Application Default + Credentials. + database: Firestore database ID. Defaults to ``"(default)"``. + collection_prefix: Optional prefix for all collection names (useful + for multi-tenant setups or test isolation). + """ + + def __init__( + self, + *, + project: Optional[str] = None, + database: str = "(default)", + collection_prefix: str = "", + ): + try: + from google.cloud.firestore_v1 import AsyncClient # noqa: F401 + except ImportError as e: + raise ImportError( + "FirestoreSessionService requires google-cloud-firestore. " + "Install it with: pip install google-cloud-firestore" + ) from e + + self._db: Any = AsyncClient(project=project, database=database) + self._prefix = collection_prefix + + # -- collection helpers -------------------------------------------------- + + def _col_app_states(self): + return self._db.collection(f"{self._prefix}{_APP_STATES_COLLECTION}") + + def _col_user_states(self): + return self._db.collection(f"{self._prefix}{_USER_STATES_COLLECTION}") + + def _col_sessions(self): + return self._db.collection(f"{self._prefix}{_SESSIONS_COLLECTION}") + + def _events_col(self, session_id: str): + return ( + self._col_sessions() + .document(session_id) + .collection(_EVENTS_SUBCOLLECTION) + ) + + # -- state helpers ------------------------------------------------------- + + async def _get_app_state(self, app_name: str) -> dict[str, Any]: + doc = await self._col_app_states().document(app_name).get() + if doc.exists: + return doc.to_dict().get(_FIELD_STATE, {}) + return {} + + async def _get_user_state( + self, app_name: str, user_id: str + ) -> dict[str, Any]: + doc_id = _user_state_doc_id(app_name, user_id) + doc = await self._col_user_states().document(doc_id).get() + if doc.exists: + return doc.to_dict().get(_FIELD_STATE, {}) + return {} + + async def _update_app_state_transactional( + self, app_name: str, delta: dict[str, Any] + ) -> dict[str, Any]: + """Atomically applies *delta* to app state inside a transaction.""" + doc_ref = self._col_app_states().document(app_name) + + @self._db.async_transactional + async def _txn(transaction): + snap = await doc_ref.get(transaction=transaction) + current = snap.to_dict().get(_FIELD_STATE, {}) if snap.exists else {} + current.update(delta) + transaction.set(doc_ref, {_FIELD_STATE: current}, merge=True) + return current + + transaction = self._db.transaction() + return await _txn(transaction) + + async def _update_user_state_transactional( + self, app_name: str, user_id: str, delta: dict[str, Any] + ) -> dict[str, Any]: + """Atomically applies *delta* to user state inside a transaction.""" + doc_id = _user_state_doc_id(app_name, user_id) + doc_ref = self._col_user_states().document(doc_id) + + @self._db.async_transactional + async def _txn(transaction): + snap = await doc_ref.get(transaction=transaction) + current = snap.to_dict().get(_FIELD_STATE, {}) if snap.exists else {} + current.update(delta) + transaction.set( + doc_ref, + { + _FIELD_APP_NAME: app_name, + _FIELD_USER_ID: user_id, + _FIELD_STATE: current, + }, + merge=True, + ) + return current + + transaction = self._db.transaction() + return await _txn(transaction) + + # -- CRUD ---------------------------------------------------------------- + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + + existing = await self._col_sessions().document(session_id).get() + if existing.exists: + raise ValueError( + f"Session with id {session_id} already exists." + ) + + deltas = _extract_state_delta(state) + app_state_delta = deltas["app"] + user_state_delta = deltas["user"] + session_state = deltas["session"] + + # Transactional state updates; reuse returned state to avoid re-read. + app_state = ( + await self._update_app_state_transactional(app_name, app_state_delta) + if app_state_delta + else await self._get_app_state(app_name) + ) + user_state = ( + await self._update_user_state_transactional( + app_name, user_id, user_state_delta + ) + if user_state_delta + else await self._get_user_state(app_name, user_id) + ) + + now = time.time() + await self._col_sessions().document(session_id).set({ + _FIELD_APP_NAME: app_name, + _FIELD_USER_ID: user_id, + _FIELD_STATE: session_state, + _FIELD_CREATE_TIME: now, + _FIELD_UPDATE_TIME: now, + }) + + merged = _merge_state(app_state, user_state, session_state) + return Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=merged, + last_update_time=now, + ) + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + doc = await self._col_sessions().document(session_id).get() + if not doc.exists: + return None + + data = doc.to_dict() + if data.get(_FIELD_APP_NAME) != app_name: + return None + if data.get(_FIELD_USER_ID) != user_id: + return None + + session_state = data.get(_FIELD_STATE, {}) + + # Build events query with server-side filtering. + events_query = self._events_col(session_id).order_by(_FIELD_TIMESTAMP) + + if config and config.after_timestamp: + events_query = events_query.where( + filter=self._db.field_filter( + _FIELD_TIMESTAMP, ">=", config.after_timestamp + ) + ) + + if config and config.num_recent_events: + events_query = events_query.limit_to_last( + config.num_recent_events + ) + + events: list[Event] = [] + async for event_doc in events_query.stream(): + raw = event_doc.to_dict().get(_FIELD_EVENT_DATA, {}) + if raw: + events.append(Event.model_validate(raw)) + + app_state = await self._get_app_state(app_name) + user_state = await self._get_user_state(app_name, user_id) + merged = _merge_state(app_state, user_state, session_state) + + return Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=merged, + events=events, + last_update_time=data.get(_FIELD_UPDATE_TIME, 0.0), + ) + + @override + async def list_sessions( + self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + query = self._col_sessions().where( + filter=self._db.field_filter(_FIELD_APP_NAME, "==", app_name) + ) + query = query.where( + filter=self._db.field_filter(_FIELD_USER_ID, "==", user_id) + ) + + # Fetch shared state once, outside the loop. + app_state = await self._get_app_state(app_name) + user_state = await self._get_user_state(app_name, user_id) + + sessions: list[Session] = [] + async for doc in query.stream(): + data = doc.to_dict() + session_state = data.get(_FIELD_STATE, {}) + merged = _merge_state(app_state, user_state, session_state) + sessions.append( + Session( + app_name=app_name, + user_id=data.get(_FIELD_USER_ID, ""), + id=doc.id, + state=merged, + last_update_time=data.get(_FIELD_UPDATE_TIME, 0.0), + ) + ) + + return ListSessionsResponse(sessions=sessions) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + session_ref = self._col_sessions().document(session_id) + doc = await session_ref.get() + if not doc.exists: + return + + # Batch-delete events in chunks of _BATCH_DELETE_LIMIT. + events_ref = session_ref.collection(_EVENTS_SUBCOLLECTION) + batch = self._db.batch() + count = 0 + async for event_doc in events_ref.stream(): + batch.delete(event_doc.reference) + count += 1 + if count >= _BATCH_DELETE_LIMIT: + await batch.commit() + batch = self._db.batch() + count = 0 + if count: + await batch.commit() + + await session_ref.delete() + + @override + async def append_event( + self, session: Session, event: Event + ) -> Event: + if event.partial: + return event + + app_name = session.app_name + user_id = session.user_id + session_id = session.id + + session_ref = self._col_sessions().document(session_id) + doc = await session_ref.get() + if not doc.exists: + logger.warning( + "Cannot append event: session %s not found.", session_id + ) + return event + + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + + if event.actions and event.actions.state_delta: + deltas = _extract_state_delta(event.actions.state_delta) + + if deltas["app"]: + await self._update_app_state_transactional( + app_name, deltas["app"] + ) + if deltas["user"]: + await self._update_user_state_transactional( + app_name, user_id, deltas["user"] + ) + if deltas["session"]: + stored_state = doc.to_dict().get(_FIELD_STATE, {}) + stored_state.update(deltas["session"]) + await session_ref.update({_FIELD_STATE: stored_state}) + + event_data = event.model_dump(exclude_none=True, mode="json") + await self._events_col(session_id).document(event.id).set({ + _FIELD_EVENT_DATA: event_data, + _FIELD_TIMESTAMP: event.timestamp, + _FIELD_INVOCATION_ID: event.invocation_id, + }) + + await session_ref.update({_FIELD_UPDATE_TIME: event.timestamp}) + return event + + async def close(self) -> None: + """Closes the underlying Firestore client.""" + self._db.close() + + async def __aenter__(self) -> FirestoreSessionService: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 00000000..a5d582eb --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,529 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for FirestoreSessionService. + +All Firestore interactions are mocked in-memory — no GCP project needed. +""" + +from __future__ import annotations + +import copy +import time +from typing import Any, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig +from google.genai import types + + +# --------------------------------------------------------------------------- +# Lightweight in-memory Firestore mock +# --------------------------------------------------------------------------- + +class _FakeDocSnapshot: + """Mimics a Firestore DocumentSnapshot.""" + + def __init__(self, doc_id: str, data: Optional[dict] = None): + self.id = doc_id + self._data = data + self.exists = data is not None + self.reference = MagicMock() + self.reference.delete = AsyncMock() + + def to_dict(self) -> dict: + return copy.deepcopy(self._data) if self._data else {} + + +class _FakeDocRef: + """Mimics a Firestore AsyncDocumentReference.""" + + def __init__(self, store: dict, doc_id: str, parent_path: str = ""): + self._store = store + self._id = doc_id + self._path = f"{parent_path}/{doc_id}" if parent_path else doc_id + + async def get(self, transaction=None) -> _FakeDocSnapshot: + data = self._store.get(self._path) + return _FakeDocSnapshot(self._id, copy.deepcopy(data)) + + async def set(self, data: dict, merge: bool = False) -> None: + if merge and self._path in self._store: + existing = self._store[self._path] + existing.update(data) + else: + self._store[self._path] = copy.deepcopy(data) + + async def update(self, data: dict) -> None: + if self._path in self._store: + self._store[self._path].update(copy.deepcopy(data)) + + async def delete(self) -> None: + self._store.pop(self._path, None) + + def collection(self, name: str): + return _FakeCollection(self._store, f"{self._path}/{name}") + + +class _FakeQuery: + """Mimics a Firestore query with where / order_by / limit_to_last.""" + + def __init__(self, docs: list[_FakeDocSnapshot]): + self._docs = docs + self._filters: list[tuple[str, str, Any]] = [] + self._order_field: Optional[str] = None + self._limit_last: Optional[int] = None + + def where(self, *, filter) -> _FakeQuery: + self._filters.append(filter) + return self + + def order_by(self, field: str) -> _FakeQuery: + self._order_field = field + return self + + def limit_to_last(self, n: int) -> _FakeQuery: + self._limit_last = n + return self + + async def stream(self): + results = list(self._docs) + for field, op, value in self._filters: + filtered = [] + for doc in results: + d = doc.to_dict() + v = d.get(field) + if op == "==" and v == value: + filtered.append(doc) + elif op == ">=" and v is not None and v >= value: + filtered.append(doc) + results = filtered + + if self._order_field: + results.sort( + key=lambda d: d.to_dict().get(self._order_field, 0) + ) + + if self._limit_last is not None: + results = results[-self._limit_last:] + + for doc in results: + yield doc + + +class _FakeCollection: + """Mimics a Firestore AsyncCollectionReference.""" + + def __init__(self, store: dict, path: str): + self._store = store + self._path = path + + def document(self, doc_id: str) -> _FakeDocRef: + return _FakeDocRef(self._store, doc_id, self._path) + + def where(self, *, filter) -> _FakeQuery: + docs = self._snapshot_docs() + q = _FakeQuery(docs) + q.where(filter=filter) + return q + + def order_by(self, field: str) -> _FakeQuery: + docs = self._snapshot_docs() + q = _FakeQuery(docs) + q.order_by(field) + return q + + def _snapshot_docs(self) -> list[_FakeDocSnapshot]: + prefix = self._path + "/" + docs = [] + for key, data in self._store.items(): + if key.startswith(prefix): + suffix = key[len(prefix):] + if "/" not in suffix: + docs.append(_FakeDocSnapshot(suffix, copy.deepcopy(data))) + return docs + + async def stream(self): + for doc in self._snapshot_docs(): + yield doc + + +class _FakeBatch: + """Mimics a Firestore WriteBatch.""" + + def __init__(self): + self._ops: list[tuple[str, Any]] = [] + + def delete(self, ref): + self._ops.append(("delete", ref)) + + async def commit(self): + for op_type, ref in self._ops: + if op_type == "delete": + await ref.delete() + self._ops.clear() + + +class _FakeTransaction: + """Mimics a Firestore async transaction.""" + + def __init__(self): + self._writes: list[tuple] = [] + + def set(self, ref, data, merge=False): + self._writes.append(("set", ref, data, merge)) + + def update(self, ref, data): + self._writes.append(("update", ref, data)) + + +class _FakeClient: + """Mimics the Firestore AsyncClient.""" + + def __init__(self): + self._store: dict[str, dict] = {} + + def collection(self, path: str): + return _FakeCollection(self._store, path) + + def transaction(self): + return _FakeTransaction() + + @staticmethod + def field_filter(field, op, value): + return (field, op, value) + + def async_transactional(self, fn): + """Wraps *fn* so it executes normally then applies writes.""" + + async def wrapper(transaction): + result = await fn(transaction) + for op_type, ref, data, *rest in transaction._writes: + if op_type == "set": + merge = rest[0] if rest else False + await ref.set(data, merge=merge) + elif op_type == "update": + await ref.update(data) + return result + + return wrapper + + def batch(self): + return _FakeBatch() + + def close(self): + pass + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest_asyncio.fixture +async def service(): + """Creates a FirestoreSessionService with a fake in-memory backend.""" + with patch( + "google.adk_community.sessions.firestore_session_service." + "FirestoreSessionService.__init__", + lambda self, **kw: None, + ): + from google.adk_community.sessions.firestore_session_service import ( + FirestoreSessionService, + ) + + svc = FirestoreSessionService() + svc._db = _FakeClient() + svc._prefix = "" + yield svc + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestFirestoreSessionService: + + @pytest.mark.asyncio + async def test_create_session(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", state={"key": "val"} + ) + assert session.app_name == "app1" + assert session.user_id == "u1" + assert session.id is not None + assert session.state["key"] == "val" + assert session.last_update_time > 0 + + @pytest.mark.asyncio + async def test_create_session_with_custom_id(self, service): + session = await service.create_session( + app_name="app1", + user_id="u1", + session_id="custom-id", + ) + assert session.id == "custom-id" + + @pytest.mark.asyncio + async def test_create_duplicate_session_raises(self, service): + await service.create_session( + app_name="app1", user_id="u1", session_id="dup" + ) + with pytest.raises(ValueError, match="already exists"): + await service.create_session( + app_name="app1", user_id="u1", session_id="dup" + ) + + @pytest.mark.asyncio + async def test_get_session(self, service): + created = await service.create_session( + app_name="app1", user_id="u1", state={"k": "v"} + ) + fetched = await service.get_session( + app_name="app1", user_id="u1", session_id=created.id + ) + assert fetched is not None + assert fetched.id == created.id + assert fetched.state["k"] == "v" + + @pytest.mark.asyncio + async def test_get_nonexistent_session(self, service): + result = await service.get_session( + app_name="app1", user_id="u1", session_id="nope" + ) + assert result is None + + @pytest.mark.asyncio + async def test_get_session_wrong_app(self, service): + created = await service.create_session( + app_name="app1", user_id="u1" + ) + result = await service.get_session( + app_name="wrong_app", user_id="u1", session_id=created.id + ) + assert result is None + + @pytest.mark.asyncio + async def test_list_sessions(self, service): + for i in range(3): + await service.create_session( + app_name="app1", + user_id="u1", + session_id=f"s{i}", + ) + resp = await service.list_sessions(app_name="app1", user_id="u1") + assert len(resp.sessions) == 3 + ids = {s.id for s in resp.sessions} + assert ids == {"s0", "s1", "s2"} + + @pytest.mark.asyncio + async def test_delete_session(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", session_id="del-me" + ) + await service.delete_session( + app_name="app1", user_id="u1", session_id="del-me" + ) + result = await service.get_session( + app_name="app1", user_id="u1", session_id="del-me" + ) + assert result is None + + @pytest.mark.asyncio + async def test_delete_nonexistent_session(self, service): + # Should not raise. + await service.delete_session( + app_name="app1", user_id="u1", session_id="ghost" + ) + + @pytest.mark.asyncio + async def test_append_event(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", session_id="ev-test" + ) + event = Event( + invocation_id="inv1", + author="user", + content=types.Content( + role="user", parts=[types.Part(text="hello")] + ), + ) + returned = await service.append_event(session=session, event=event) + assert returned.id == event.id + + fetched = await service.get_session( + app_name="app1", user_id="u1", session_id="ev-test" + ) + assert len(fetched.events) == 1 + assert fetched.events[0].content.parts[0].text == "hello" + + @pytest.mark.asyncio + async def test_append_event_partial_skipped(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", session_id="partial-test" + ) + event = Event(author="user", partial=True) + result = await service.append_event(session=session, event=event) + assert result is event + + fetched = await service.get_session( + app_name="app1", user_id="u1", session_id="partial-test" + ) + assert len(fetched.events) == 0 + + @pytest.mark.asyncio + async def test_append_event_with_state_delta(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", session_id="delta-test" + ) + event = Event( + invocation_id="inv1", + author="agent", + actions=EventActions( + state_delta={ + "app:color": "blue", + "user:lang": "en", + "local_key": "local_val", + } + ), + ) + await service.append_event(session=session, event=event) + + fetched = await service.get_session( + app_name="app1", user_id="u1", session_id="delta-test" + ) + assert fetched.state.get("app:color") == "blue" + assert fetched.state.get("user:lang") == "en" + assert fetched.state.get("local_key") == "local_val" + + @pytest.mark.asyncio + async def test_app_state_shared_across_sessions(self, service): + s1 = await service.create_session( + app_name="shared", + user_id="u1", + session_id="s1", + state={"app:version": "1.0"}, + ) + s2 = await service.create_session( + app_name="shared", user_id="u1", session_id="s2" + ) + assert s2.state.get("app:version") == "1.0" + + @pytest.mark.asyncio + async def test_user_state_shared_across_sessions(self, service): + s1 = await service.create_session( + app_name="app1", + user_id="u1", + session_id="us1", + state={"user:pref": "dark"}, + ) + s2 = await service.create_session( + app_name="app1", user_id="u1", session_id="us2" + ) + assert s2.state.get("user:pref") == "dark" + + @pytest.mark.asyncio + async def test_get_session_num_recent_events(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", session_id="recent" + ) + for i in range(5): + event = Event( + invocation_id=f"inv{i}", + author="user", + timestamp=float(i + 1), + ) + await service.append_event(session=session, event=event) + + config = GetSessionConfig(num_recent_events=2) + fetched = await service.get_session( + app_name="app1", + user_id="u1", + session_id="recent", + config=config, + ) + assert len(fetched.events) == 2 + assert fetched.events[0].timestamp == 4.0 + assert fetched.events[1].timestamp == 5.0 + + @pytest.mark.asyncio + async def test_get_session_after_timestamp(self, service): + session = await service.create_session( + app_name="app1", user_id="u1", session_id="after" + ) + for i in range(5): + event = Event( + invocation_id=f"inv{i}", + author="user", + timestamp=float(i + 1), + ) + await service.append_event(session=session, event=event) + + config = GetSessionConfig(after_timestamp=3.0) + fetched = await service.get_session( + app_name="app1", + user_id="u1", + session_id="after", + config=config, + ) + assert len(fetched.events) == 3 + assert fetched.events[0].timestamp == 3.0 + + @pytest.mark.asyncio + async def test_close_and_context_manager(self, service): + async with service: + session = await service.create_session( + app_name="app1", user_id="u1" + ) + assert session is not None + + @pytest.mark.asyncio + async def test_temp_state_not_persisted(self, service): + session = await service.create_session( + app_name="app1", + user_id="u1", + session_id="temp-test", + state={"temp:scratch": "gone", "keep": "this"}, + ) + assert session.state.get("keep") == "this" + assert "temp:scratch" not in session.state + + @pytest.mark.asyncio + async def test_collection_prefix(self): + with patch( + "google.adk_community.sessions.firestore_session_service." + "FirestoreSessionService.__init__", + lambda self, **kw: None, + ): + from google.adk_community.sessions.firestore_session_service import ( + FirestoreSessionService, + ) + + svc = FirestoreSessionService() + svc._db = _FakeClient() + svc._prefix = "test_" + + session = await svc.create_session( + app_name="app1", user_id="u1", session_id="prefixed" + ) + assert session.id == "prefixed" + + fetched = await svc.get_session( + app_name="app1", user_id="u1", session_id="prefixed" + ) + assert fetched is not None