Skip to content

Commit 19d3649

Browse files
authored
Update Google ADK extension (#163)
1 parent 14004a1 commit 19d3649

3 files changed

Lines changed: 59 additions & 20 deletions

File tree

python/restate/ext/adk/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,33 @@
88
# directory of this repository or package, or at
99
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
1010
#
11+
import typing
1112

1213
from .session import RestateSessionService
1314
from .plugin import RestatePlugin
15+
from restate import ObjectContext, Context
16+
from restate.extensions import current_context
17+
18+
19+
def restate_object_context() -> ObjectContext:
20+
"""Get the current Restate ObjectContext."""
21+
ctx = current_context()
22+
if ctx is None:
23+
raise RuntimeError("No Restate context found.")
24+
return typing.cast(ObjectContext, ctx)
25+
26+
27+
def restate_context() -> Context:
28+
"""Get the current Restate Context."""
29+
ctx = current_context()
30+
if ctx is None:
31+
raise RuntimeError("No Restate context found.")
32+
return ctx
33+
1434

1535
__all__ = [
1636
"RestateSessionService",
1737
"RestatePlugin",
38+
"restate_object_context",
39+
"restate_context",
1840
]

python/restate/ext/adk/plugin.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from google.adk.agents import BaseAgent, LlmAgent
2424
from google.adk.agents.callback_context import CallbackContext
25+
from google.adk.agents.invocation_context import InvocationContext
2526
from google.adk.plugins import BasePlugin
2627
from google.adk.tools.base_tool import BaseTool
2728
from google.adk.tools.tool_context import ToolContext
@@ -30,12 +31,10 @@
3031
from google.adk.models import LLMRegistry
3132
from google.adk.models.base_llm import BaseLlm
3233
from google.adk.flows.llm_flows.functions import generate_client_function_call_id
33-
34+
from restate.ext.adk import RestateSessionService
3435

3536
from restate.extensions import current_context
3637

37-
from .session import flush_session_state
38-
3938

4039
class RestatePlugin(BasePlugin):
4140
"""A plugin to integrate Restate with the ADK framework."""
@@ -84,12 +83,13 @@ async def after_agent_callback(
8483
) -> Optional[types.Content]:
8584
self._models.pop(callback_context.invocation_id, None)
8685
self._locks.pop(callback_context.invocation_id, None)
87-
88-
ctx = cast(restate.ObjectContext, current_context())
89-
await flush_session_state(ctx, callback_context.session)
90-
9186
return None
9287

88+
async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
89+
if isinstance(invocation_context.session_service, RestateSessionService):
90+
restate_session_service = cast(RestateSessionService, invocation_context.session_service)
91+
await restate_session_service.flush_session_state(invocation_context.session)
92+
9393
async def before_model_callback(
9494
self, *, callback_context: CallbackContext, llm_request: LlmRequest
9595
) -> Optional[LlmResponse]:
@@ -109,9 +109,10 @@ async def before_tool_callback(
109109
tool_args: dict[str, Any],
110110
tool_context: ToolContext,
111111
) -> Optional[dict]:
112-
tool_context.session.state["restate_context"] = current_context()
113112
lock = self._locks[tool_context.invocation_id]
113+
ctx = current_context()
114114
await lock.acquire()
115+
tool_context.session.state["restate_context"] = ctx
115116
# TODO: if we want we can also automatically wrap tools with ctx.run_typed here
116117
return None
117118

@@ -123,9 +124,9 @@ async def after_tool_callback(
123124
tool_context: ToolContext,
124125
result: dict,
125126
) -> Optional[dict]:
127+
tool_context.session.state.pop("restate_context", None)
126128
lock = self._locks[tool_context.invocation_id]
127129
lock.release()
128-
tool_context.session.state.pop("restate_context", None)
129130
return None
130131

131132
async def on_tool_error_callback(
@@ -136,9 +137,9 @@ async def on_tool_error_callback(
136137
tool_context: ToolContext,
137138
error: Exception,
138139
) -> Optional[dict]:
140+
tool_context.session.state.pop("restate_context", None)
139141
lock = self._locks[tool_context.invocation_id]
140142
lock.release()
141-
tool_context.session.state.pop("restate_context", None)
142143
return None
143144

144145
async def close(self):

python/restate/ext/adk/session.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
from typing import Optional, Any, cast
1818

1919
from google.adk.sessions import Session
20+
from google.adk.sessions.state import State
2021
from google.adk.events.event import Event
2122
from google.adk.sessions.base_session_service import (
2223
BaseSessionService,
2324
ListSessionsResponse,
2425
GetSessionConfig,
2526
)
27+
from restate import TerminalError
2628

2729
from restate.extensions import current_context
2830

@@ -42,15 +44,23 @@ async def create_session(
4244
if session_id is None:
4345
session_id = str(self.ctx().uuid())
4446

45-
session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session(
47+
session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session)
48+
if session is not None:
49+
raise TerminalError("Session with the given ID already exists.")
50+
51+
session = Session(
4652
app_name=app_name,
4753
user_id=user_id,
4854
id=session_id,
4955
state=state or {},
5056
)
51-
self.ctx().set(f"session_store::{session_id}", session)
57+
58+
await self.flush_session_state(session)
5259
return session
5360

61+
async def has_session(self, *, session_id: str) -> bool:
62+
return await self.ctx().get(f"session_store::{session_id}", type_hint=Session) is not None
63+
5464
async def get_session(
5565
self,
5666
*,
@@ -89,12 +99,18 @@ async def append_event(self, session: Session, event: Event) -> Event:
8999
session.events.append(event)
90100
return event
91101

102+
async def flush_session_state(self, session: Session):
103+
session_to_store = session.model_copy()
92104

93-
async def flush_session_state(ctx: restate.ObjectContext, session: Session):
94-
session_to_store = session.model_copy()
95-
# Remove restate-specific context that got added by the plugin before storing
96-
session_to_store.state.pop("restate_context", None)
97-
deterministic_session = await ctx.run_typed(
98-
"store session", lambda: session_to_store, restate.RunOptions(type_hint=Session)
99-
)
100-
ctx.set(f"session_store::{session.id}", deterministic_session)
105+
# Remove temporary state keys before storing
106+
for key in list(session_to_store.state.keys()):
107+
if key.startswith(State.TEMP_PREFIX):
108+
session_to_store.state.pop(key)
109+
110+
# Remove restate-specific context that got added by the plugin before storing
111+
session_to_store.state.pop("restate_context", None)
112+
113+
deterministic_session = await self.ctx().run_typed(
114+
"store session", lambda: session_to_store, restate.RunOptions(type_hint=Session)
115+
)
116+
self.ctx().set(f"session_store::{session.id}", deterministic_session)

0 commit comments

Comments
 (0)