diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py index 06b565fd5b..78de947acd 100644 --- a/src/google/adk/cli/utils/local_storage.py +++ b/src/google/adk/cli/utils/local_storage.py @@ -18,6 +18,7 @@ import asyncio import logging from pathlib import Path +from typing import Any from typing import Mapping from typing import Optional @@ -205,6 +206,13 @@ async def delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + service = await self._get_service(app_name) + return await service.get_user_state(app_name=app_name, user_id=user_id) + @override async def append_event(self, session: Session, event: Event) -> Event: service = await self._get_service(session.app_name) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 324abe230b..06eb6a2534 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -111,6 +111,46 @@ async def delete_session( ) -> None: """Deletes a session.""" + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Returns the user-scoped state for the given app and user. + + User state is keyed by ``(app_name, user_id)`` and shared across all + sessions of the same user within the same app. The returned dictionary + uses raw keys **without** the ``user:`` prefix (e.g. ``"my_key"`` rather + than ``"user:my_key"``). + + This method exists so that callers can read user state without holding an + active ``session_id``. A common use case is bootstrapping context at the + start of a new session before calling ``create_session``, which would + otherwise require an expensive ``list_sessions`` call just to access + user-scoped data. + + Returns an empty dict when no user state has been stored for this + ``(app_name, user_id)`` combination. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + + Returns: + A dictionary of raw (un-prefixed) user-scoped key/value pairs, or an + empty dict when no user state exists. + + Raises: + NotImplementedError: When the concrete ``BaseSessionService`` + implementation does not support reading user state independently of a + session. Callers should catch this, then enumerate sessions via + ``list_sessions`` and call ``get_session`` on each result to access + the merged state, or accept that user state is unavailable. + """ + raise NotImplementedError( + f'{type(self).__name__} does not support get_user_state. ' + 'To read user state, enumerate sessions via list_sessions and ' + 'call get_session on each result to access the merged state.' + ) + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 331683af92..2fdd673334 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -329,7 +329,7 @@ async def _with_session_lock( else: self._session_lock_ref_count[lock_key] = remaining - async def _prepare_tables(self): + async def _prepare_tables(self) -> None: """Ensure database tables are ready for use. This method is called lazily before each database operation. It checks the @@ -627,6 +627,22 @@ async def delete_session( await sql_session.execute(stmt) await sql_session.commit() + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + await self._prepare_tables() + schema = self._get_schema_classes() + async with self._rollback_on_exception_session( + read_only=True + ) as sql_session: + storage_user_state = await sql_session.get( + schema.StorageUserState, (app_name, user_id) + ) + if storage_user_state is None: + return {} + return dict(storage_user_state.state or {}) + @override async def append_event(self, session: Session, event: Event) -> Event: await self._prepare_tables() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index b8f6cfab46..73a54f398b 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -312,6 +312,12 @@ def _delete_session_impl( self.sessions[app_name][user_id].pop(session_id) + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + return dict(self.user_state.get(app_name, {}).get(user_id, {})) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 798befcedc..6e0f60db1d 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -359,6 +359,13 @@ async def delete_session( ) await db.commit() + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + async with self._get_db_connection() as db: + return await self._get_user_state(db, app_name, user_id) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index e7e1dc052a..5026bc6ebb 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -276,6 +276,27 @@ async def delete_session( logger.error('Error deleting session %s: %s', session_id, e) raise + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Not supported by the Vertex AI Agent Engine backend. + + The Vertex AI Agent Engine API does not expose user state independently of + a session. To read user state, enumerate sessions via ``list_sessions`` + and call ``get_session`` on each result to access the merged state. + + Raises: + NotImplementedError: Always, because the Vertex AI Agent Engine API does + not provide a way to query user state without a session. + """ + raise NotImplementedError( + 'VertexAiSessionService does not support get_user_state. ' + 'The Vertex AI Agent Engine API does not expose user state ' + 'independently of a session. To read user state, enumerate sessions ' + 'via list_sessions and call get_session on each result.' + ) + @override async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. diff --git a/tests/unittests/cli/utils/test_local_storage.py b/tests/unittests/cli/utils/test_local_storage.py index a11a82fad7..2cdc463e5c 100644 --- a/tests/unittests/cli/utils/test_local_storage.py +++ b/tests/unittests/cli/utils/test_local_storage.py @@ -19,6 +19,8 @@ from google.adk.cli.utils.local_storage import create_local_database_session_service from google.adk.cli.utils.local_storage import create_local_session_service from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.sessions.sqlite_session_service import SqliteSessionService import pytest @@ -90,3 +92,28 @@ def test_create_local_database_session_service_returns_sqlite( service = create_local_database_session_service(base_dir=tmp_path) assert isinstance(service, SqliteSessionService) + + +@pytest.mark.asyncio +async def test_per_agent_session_service_get_user_state(tmp_path: Path) -> None: + agent_a = tmp_path / 'agent_a' + agent_b = tmp_path / 'agent_b' + agent_a.mkdir() + agent_b.mkdir() + + service = PerAgentDatabaseSessionService(agents_root=tmp_path) + + session_a = await service.create_session(app_name='agent_a', user_id='user_a') + await service.append_event( + session_a, + Event( + author='system', + actions=EventActions(state_delta={'user:profile': {'name': 'Alice'}}), + ), + ) + + state_a = await service.get_user_state(app_name='agent_a', user_id='user_a') + state_b = await service.get_user_state(app_name='agent_b', user_id='user_b') + + assert state_a == {'profile': {'name': 'Alice'}} + assert state_b == {} diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 02f5159a45..fc9009438a 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -30,6 +30,7 @@ from google.adk.sessions.database_session_service import DatabaseSessionService from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.sqlite_session_service import SqliteSessionService +from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService from google.genai import types import pytest from sqlalchemy import delete @@ -1650,3 +1651,124 @@ async def tracking_fn(**kwargs): finally: database_session_service._select_required_state = original_fn await service.close() + + +@pytest.mark.asyncio +async def test_get_user_state_returns_empty_dict_when_no_state_exists( + session_service, +): + state = await session_service.get_user_state( + app_name='my_app', user_id='u1' + ) + assert state == {} + + +@pytest.mark.asyncio +async def test_get_user_state_returns_state_written_via_append_event( + session_service, +): + session = await session_service.create_session( + app_name='my_app', user_id='u1' + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions( + state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1} + ), + ), + ) + + state = await session_service.get_user_state(app_name='my_app', user_id='u1') + + assert state == {'profile': {'name': 'Alice'}} + assert 'session_key' not in state + + +@pytest.mark.asyncio +async def test_get_user_state_is_not_visible_across_users(session_service): + session = await session_service.create_session( + app_name='my_app', user_id='u1' + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:secret': 'only-for-u1'}), + ), + ) + + other_state = await session_service.get_user_state( + app_name='my_app', user_id='u2' + ) + assert other_state == {} + + +@pytest.mark.asyncio +async def test_get_user_state_is_not_visible_across_apps(session_service): + session = await session_service.create_session( + app_name='my_app', user_id='u1' + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:data': 'only-app-a'}), + ), + ) + + other_state = await session_service.get_user_state( + app_name='other_app', user_id='u1' + ) + assert other_state == {} + + +@pytest.mark.asyncio +async def test_get_user_state_available_before_session_is_created( + session_service, +): + first_session = await session_service.create_session( + app_name='my_app', user_id='u1' + ) + await session_service.append_event( + first_session, + Event( + author='system', + actions=EventActions(state_delta={'user:ctx': {'v': 1}}), + ), + ) + + state = await session_service.get_user_state(app_name='my_app', user_id='u1') + assert state == {'ctx': {'v': 1}} + + +@pytest.mark.asyncio +async def test_get_user_state_reflects_latest_write(session_service): + session = await session_service.create_session( + app_name='my_app', user_id='u1' + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:counter': 1}), + ), + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:counter': 2}), + ), + ) + + state = await session_service.get_user_state(app_name='my_app', user_id='u1') + assert state['counter'] == 2 + + +@pytest.mark.asyncio +async def test_vertex_ai_session_service_raises_not_implemented_for_get_user_state(): + service = VertexAiSessionService(project='proj', location='us-central1') + with pytest.raises(NotImplementedError): + await service.get_user_state(app_name='my_app', user_id='u1')