Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import abc
import copy
from typing import Any
from typing import Optional

Expand Down Expand Up @@ -102,6 +103,94 @@ async def delete_session(
) -> None:
"""Deletes a session."""

@abc.abstractmethod
async def clone_session(
self,
*,
app_name: str,
src_user_id: str,
src_session_id: Optional[str] = None,
new_user_id: Optional[str] = None,
new_session_id: Optional[str] = None,
) -> Session:
"""Clones session(s) and their events to a new session.

This method supports two modes:

1. Single session clone: When `src_session_id` is provided, clones that
specific session to the new session.

2. All sessions clone: When `src_session_id` is NOT provided, finds all
sessions for `src_user_id` and merges ALL their events into a single
new session.

Events are automatically deduplicated by event ID - only the first
occurrence of each event ID is kept.

Args:
app_name: The name of the app.
src_user_id: The source user ID whose session(s) to clone.
src_session_id: The source session ID to clone. If not provided, all
sessions for the source user will be merged into one new session.
new_user_id: The user ID for the new session. If not provided, uses the
same user_id as the source.
new_session_id: The session ID for the new session. If not provided, a
new ID will be auto-generated (UUID4).

Returns:
The newly created session with cloned events.

Raises:
ValueError: If no source sessions are found.
AlreadyExistsError: If a session with new_session_id already exists.
"""

def _prepare_sessions_for_cloning(
self, source_sessions: list[Session]
) -> tuple[dict[str, Any], list[Event]]:
"""Prepares source sessions for cloning by merging states and deduplicating events.

This is a shared helper method used by all clone_session implementations
to ensure consistent behavior across different session service backends.

The method:
1. Sorts sessions by last_update_time for deterministic state merging
2. Merges states from all sessions (later sessions overwrite earlier ones)
3. Collects all events, sorts by timestamp, and deduplicates by event ID

Args:
source_sessions: List of source sessions to process.

Returns:
A tuple of (merged_state, deduplicated_events):
- merged_state: Combined state from all sessions (deep copied)
- deduplicated_events: Chronologically sorted, deduplicated events
"""
# Sort sessions by update time for deterministic state merging
# Use sorted() to avoid modifying the input list in-place
sorted_sessions = sorted(source_sessions, key=lambda s: s.last_update_time)

# Merge states from all source sessions
merged_state: dict[str, Any] = {}
for session in sorted_sessions:
merged_state.update(copy.deepcopy(session.state))

# Collect all events, sort by timestamp, then deduplicate
# to ensure chronological "first occurrence wins"
all_source_events: list[Event] = []
for session in sorted_sessions:
all_source_events.extend(session.events)
all_source_events.sort(key=lambda e: e.timestamp)

all_events: list[Event] = []
seen_event_ids: set[str] = set()
for event in all_source_events:
if event.id not in seen_event_ids:
seen_event_ids.add(event.id)
all_events.append(event)

return merged_state, all_events

async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
Expand Down
96 changes: 94 additions & 2 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,98 @@ async def delete_session(
await sql_session.execute(stmt)
await sql_session.commit()

@override
async def clone_session(
self,
*,
app_name: str,
src_user_id: str,
src_session_id: Optional[str] = None,
new_user_id: Optional[str] = None,
new_session_id: Optional[str] = None,
) -> Session:
await self._prepare_tables()

# Use source values as defaults
new_user_id = new_user_id or src_user_id

schema = self._get_schema_classes()

# Collect source sessions and their events
source_sessions = []
if src_session_id:
# Single session clone - use get_session (no N+1 issue)
session = await self.get_session(
app_name=app_name,
user_id=src_user_id,
session_id=src_session_id,
)
if not session:
raise ValueError(
f"Source session {src_session_id} not found for user {src_user_id}."
)
source_sessions.append(session)
else:
# All sessions clone - optimized to avoid N+1 query problem
# Step 1: Get all sessions with state (no events)
list_response = await self.list_sessions(
app_name=app_name, user_id=src_user_id
)
if not list_response.sessions:
raise ValueError(f"No sessions found for user {src_user_id}.")

session_ids = [sess.id for sess in list_response.sessions]

# Step 2: Fetch ALL events for all session IDs in a single query
async with self.database_session_factory() as sql_session:
stmt = (
select(schema.StorageEvent)
.filter(schema.StorageEvent.app_name == app_name)
.filter(schema.StorageEvent.user_id == src_user_id)
.filter(schema.StorageEvent.session_id.in_(session_ids))
.order_by(schema.StorageEvent.timestamp.asc())
)
result = await sql_session.execute(stmt)
all_storage_events = result.scalars().all()

# Step 3: Map events back to sessions
events_by_session_id = {}
for storage_event in all_storage_events:
events_by_session_id.setdefault(storage_event.session_id, []).append(
storage_event.to_event()
)

# Build full session objects with events
for sess in list_response.sessions:
sess.events = events_by_session_id.get(sess.id, [])
source_sessions.append(sess)

# Use shared helper for state merging and event deduplication
merged_state, all_events = self._prepare_sessions_for_cloning(
source_sessions
)

# Create the new session (new_session_id=None triggers UUID4 generation)
new_session = await self.create_session(
app_name=app_name,
user_id=new_user_id,
state=merged_state,
session_id=new_session_id,
)

# Copy events to the new session using bulk insert
async with self.database_session_factory() as sql_session:
new_storage_events = [
schema.StorageEvent.from_event(new_session, copy.deepcopy(event))
for event in all_events
]
sql_session.add_all(new_storage_events)
await sql_session.commit()

# Return the new session with events (avoid redundant DB query)
new_session.events = all_events
return new_session

@override
async def append_event(self, session: Session, event: Event) -> Event:
await self._prepare_tables()
Expand All @@ -436,8 +528,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
if storage_session.update_timestamp_tz > session.last_update_time:
raise ValueError(
"The last_update_time provided in the session object"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
" earlier than the update_time in the storage_session"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'}"
" is earlier than the update_time in the storage_session"
f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
" Please check if it is a stale session."
)
Expand Down
98 changes: 98 additions & 0 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,104 @@ def _delete_session_impl(

self.sessions[app_name][user_id].pop(session_id)

@override
async def clone_session(
self,
*,
app_name: str,
src_user_id: str,
src_session_id: Optional[str] = None,
new_user_id: Optional[str] = None,
new_session_id: Optional[str] = None,
) -> Session:
return self._clone_session_impl(
app_name=app_name,
src_user_id=src_user_id,
src_session_id=src_session_id,
new_user_id=new_user_id,
new_session_id=new_session_id,
)

def _clone_session_impl(
self,
*,
app_name: str,
src_user_id: str,
src_session_id: Optional[str] = None,
new_user_id: Optional[str] = None,
new_session_id: Optional[str] = None,
) -> Session:
# Use source values as defaults
new_user_id = new_user_id or src_user_id

# Collect source sessions and their events
source_sessions = []
if src_session_id:
# Single session clone
session = self._get_session_impl(
app_name=app_name,
user_id=src_user_id,
session_id=src_session_id,
)
if not session:
raise ValueError(
f'Source session {src_session_id} not found for user {src_user_id}.'
)
source_sessions.append(session)
else:
# All sessions clone - optimized direct access to avoid N+1 lookups
if (
app_name not in self.sessions
or src_user_id not in self.sessions[app_name]
):
raise ValueError(f'No sessions found for user {src_user_id}.')

user_sessions = self.sessions[app_name][src_user_id]
if not user_sessions:
raise ValueError(f'No sessions found for user {src_user_id}.')

# Directly access storage sessions and build full session objects
for session_id, storage_session in user_sessions.items():
# Deep copy the session to avoid mutations
copied_session = copy.deepcopy(storage_session)
# Merge state with app and user state
copied_session = self._merge_state(
app_name, src_user_id, copied_session
)
source_sessions.append(copied_session)

# Use shared helper for state merging and event deduplication
merged_state, all_events = self._prepare_sessions_for_cloning(
source_sessions
)
# Deep copy events for in-memory storage isolation
all_events = [copy.deepcopy(event) for event in all_events]

# Create the new session (new_session_id=None triggers UUID4 generation)
new_session = self._create_session_impl(
app_name=app_name,
user_id=new_user_id,
state=merged_state,
session_id=new_session_id,
)

# Get latest update time explicitly (don't rely on sorting side effects)
latest_update_time = (
max(s.last_update_time for s in source_sessions)
if source_sessions
else 0.0
)

# Get the storage session and set events
storage_session = self.sessions[app_name][new_user_id][new_session.id]
storage_session.events = all_events
storage_session.last_update_time = latest_update_time

# Return the new session with events (avoid redundant lookup)
new_session.events = all_events
new_session.last_update_time = latest_update_time
return new_session

@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
Expand Down
Loading