diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index d9bb047a3..3e96f959e 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -186,13 +186,21 @@ public Completable deleteSession(String appName, String userId, String sessionId Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - ConcurrentMap> appSessionsMap = sessions.get(appName); - if (appSessionsMap != null) { - ConcurrentMap userSessionsMap = appSessionsMap.get(userId); - if (userSessionsMap != null) { - userSessionsMap.remove(sessionId); - } - } + sessions.computeIfPresent( + appName, + (app, appSessionsMap) -> { + appSessionsMap.computeIfPresent( + userId, + (user, userSessionsMap) -> { + userSessionsMap.remove(sessionId); + // If userSessionsMap is now empty, return null to automatically remove the userId + // key + return userSessionsMap.isEmpty() ? null : userSessionsMap; + }); + // If appSessionsMap is now empty, return null to automatically remove the appName key + return appSessionsMap.isEmpty() ? null : appSessionsMap; + }); + return Completable.complete(); } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 0d9235b1b..c260f6695 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Field; import java.time.Instant; import java.util.HashMap; import java.util.Optional; @@ -266,4 +267,54 @@ public void sequentialAgents_shareTempState() { assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output"); assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data"); } + + @Test + public void deleteSession_cleansUpEmptyParentMaps() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + + Session session = sessionService.createSession("app-name", "user-id").blockingGet(); + + sessionService + .deleteSession(session1.appName(), session1.userId(), session1.id()) + .blockingAwait(); + + // Use reflection to access the private 'sessions' field + Field field = InMemorySessionService.class.getDeclaredField("sessions"); + field.setAccessible(true); + ConcurrentMap sessions = (ConcurrentMap) field.get(sessionService); + + // After deleting the only session for "user-id" under "app-name", + // both the userId map and the appName map should have been removed + assertThat(sessions).isEmpty(); + } + + @Test + public void deleteSession_doesNotRemoveUserMapWhenOtherSessionsExist() throws Exception { + InMemorySessionService sessionService = new InMemorySessionService(); + + Session session1 = sessionService.createSession("app-name", "user-id").blockingGet(); + Session session2 = sessionService.createSession("app-name", "user-id").blockingGet(); + + // Delete only one of the two sessions + sessionService.deleteSession(session1.appName(), session1.userId(), session1.id()).blockingAwait(); + + // session2 should still be retrievable + assertThat( + sessionService + .getSession(session2.appName(), session2.userId(), session2.id(), Optional.empty()) + .blockingGet()) + .isNotNull(); + + // The userId entry should still exist (not pruned) because session2 remains + Field field = InMemorySessionService.class.getDeclaredField("sessions"); + field.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap>> sessions = + (ConcurrentMap>>) + field.get(sessionService); + + assertThat(sessions.get("app-name")).isNotNull(); + assertThat(sessions.get("app-name").get("user-id")).isNotNull(); + assertThat(sessions.get("app-name").get("user-id")).hasSize(1); + } }