2222
2323from google .adk .agents import BaseAgent , LlmAgent
2424from google .adk .agents .callback_context import CallbackContext
25+ from google .adk .agents .invocation_context import InvocationContext
2526from google .adk .plugins import BasePlugin
2627from google .adk .tools .base_tool import BaseTool
2728from google .adk .tools .tool_context import ToolContext
3031from google .adk .models import LLMRegistry
3132from google .adk .models .base_llm import BaseLlm
3233from google .adk .flows .llm_flows .functions import generate_client_function_call_id
33-
34+ from restate . ext . adk import RestateSessionService
3435
3536from restate .extensions import current_context
3637
37- from .session import flush_session_state
38-
3938
4039class RestatePlugin (BasePlugin ):
4140 """A plugin to integrate Restate with the ADK framework."""
@@ -84,12 +83,13 @@ async def after_agent_callback(
8483 ) -> Optional [types .Content ]:
8584 self ._models .pop (callback_context .invocation_id , None )
8685 self ._locks .pop (callback_context .invocation_id , None )
87-
88- ctx = cast (restate .ObjectContext , current_context ())
89- await flush_session_state (ctx , callback_context .session )
90-
9186 return None
9287
88+ async def after_run_callback (self , * , invocation_context : InvocationContext ) -> None :
89+ if isinstance (invocation_context .session_service , RestateSessionService ):
90+ restate_session_service = cast (RestateSessionService , invocation_context .session_service )
91+ await restate_session_service .flush_session_state (invocation_context .session )
92+
9393 async def before_model_callback (
9494 self , * , callback_context : CallbackContext , llm_request : LlmRequest
9595 ) -> Optional [LlmResponse ]:
@@ -109,9 +109,10 @@ async def before_tool_callback(
109109 tool_args : dict [str , Any ],
110110 tool_context : ToolContext ,
111111 ) -> Optional [dict ]:
112- tool_context .session .state ["restate_context" ] = current_context ()
113112 lock = self ._locks [tool_context .invocation_id ]
113+ ctx = current_context ()
114114 await lock .acquire ()
115+ tool_context .session .state ["restate_context" ] = ctx
115116 # TODO: if we want we can also automatically wrap tools with ctx.run_typed here
116117 return None
117118
@@ -123,9 +124,9 @@ async def after_tool_callback(
123124 tool_context : ToolContext ,
124125 result : dict ,
125126 ) -> Optional [dict ]:
127+ tool_context .session .state .pop ("restate_context" , None )
126128 lock = self ._locks [tool_context .invocation_id ]
127129 lock .release ()
128- tool_context .session .state .pop ("restate_context" , None )
129130 return None
130131
131132 async def on_tool_error_callback (
@@ -136,9 +137,9 @@ async def on_tool_error_callback(
136137 tool_context : ToolContext ,
137138 error : Exception ,
138139 ) -> Optional [dict ]:
140+ tool_context .session .state .pop ("restate_context" , None )
139141 lock = self ._locks [tool_context .invocation_id ]
140142 lock .release ()
141- tool_context .session .state .pop ("restate_context" , None )
142143 return None
143144
144145 async def close (self ):
0 commit comments