diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ab5f6567a..0be78944b 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -91,8 +91,9 @@ public BaseLlmFlow( * RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the * events generated by them. */ - protected Flowable preprocess( + private Flowable preprocess( InvocationContext context, AtomicReference llmRequestRef) { + Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -114,6 +115,7 @@ protected Flowable preprocess( .concatMap( processor -> Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + .compose(Tracing.withContext(currentContext)) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? result.events() : ImmutableList.of())); @@ -128,7 +130,8 @@ protected Flowable postprocess( InvocationContext context, Event baseEventForLlmResponse, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + Context parentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); @@ -144,15 +147,16 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } - Context parentContext = Context.current(); - return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, + eventIterables, + context, + baseEventForLlmResponse, + llmRequest, + parentContext) + .compose(Tracing.withContext(parentContext))); } /** @@ -163,54 +167,80 @@ protected Flowable postprocess( * @param eventForCallbackUsage An Event object primarily for providing context (like actions) to * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ - private Flowable callLlm( + private Flowable callLlm( Context spanContext, InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) .toFlowable() + .concatMap( + llmResp -> + postprocess( + context, + eventForCallbackUsage, + llmRequestBuilder.build(), + llmResp, + spanContext)) .switchIfEmpty( Flowable.defer( () -> { + LlmAgent agent = (LlmAgent) context.agent(); BaseLlm llm = agent.resolvedModel().model().isPresent() ? agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose( - Tracing.trace("call_llm") - .setParent(spanContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, + LlmRequest finalLlmRequest = llmRequestBuilder.build(); + + Span span = + Tracing.getTracer() + .spanBuilder("call_llm") + .setParent(spanContext) + .startSpan(); + Context callLlmContext = spanContext.with(span); + + Flowable flowable = + llm.generateContent( + finalLlmRequest, + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnError( + error -> { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()) + .flatMap( + llmResp -> + postprocess( context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); + eventForCallbackUsage, + finalLlmRequest, + llmResp, + callLlmContext) + .doOnSubscribe( + s -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + finalLlmRequest, + llmResp))); + + return Tracing.traceFlowable(callLlmContext, span, () -> flowable); })); } @@ -222,6 +252,7 @@ private Flowable callLlm( */ private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -240,7 +271,11 @@ private Maybe handleBeforeModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); @@ -257,6 +292,7 @@ private Maybe handleOnModelErrorCallback( LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, Throwable throwable) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -277,7 +313,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(currentContext))) .firstElement(); }); @@ -292,6 +332,7 @@ private Maybe handleOnModelErrorCallback( */ private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -310,7 +351,11 @@ private Single handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -362,23 +407,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - String newId = Event.generateEventId(); - logger.debug( - "Resetting event ID from {} to {}", oldId, newId); - mutableEventTemplate.setId(newId); - }); - } + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug("Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); }) .concatMap( event -> { @@ -545,6 +579,10 @@ public void onError(Throwable e) { .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)); + Span span = + Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan(); + Context callLlmContext = spanContext.with(span); + Flowable receiveFlow = connection .receive() @@ -556,7 +594,8 @@ public void onError(Throwable e) { invocationContext, baseEventForThisLlmResponse, llmRequestAfterPreprocess, - llmResponse); + llmResponse, + callLlmContext); }) .flatMap( event -> { @@ -592,7 +631,12 @@ public void onError(Throwable e) { } }); - return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false)); + return Tracing.traceFlowable( + callLlmContext, + span, + () -> + receiveFlow.takeWhile( + event -> !event.actions().endInvocation().orElse(false))); })); } @@ -608,7 +652,8 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest) { + LlmRequest llmRequest, + Context parentContext) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -624,21 +669,23 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); + Flowable functionEvents; + try (Scope scope = parentContext.makeCurrent()) { + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + } return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index c1a996064..84a8141ea 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -42,7 +42,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; -import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Observable; @@ -163,7 +162,9 @@ public static Maybe handleFunctionCalls( } return functionResponseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -226,7 +227,9 @@ public static Maybe handleFunctionCallsLive( return responseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -243,47 +246,45 @@ private static Function> getFunctionCallMapper( Context parentContext) { return functionCall -> Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation( - functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = - functionCall.args().map(HashMap::new).orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> { - try (Scope innerScope = parentContext.makeCurrent()) { - return isLive - ? processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs, - parentContext) - : callTool(tool, functionArgs, toolContext, parentContext); - } - })); - - return postProcessFunctionResult( - maybeFunctionResult, - invocationContext, - tool, - functionArgs, - toolContext, - isLive, - parentContext); - } - }); + () -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs, + parentContext) + : callTool( + tool, functionArgs, toolContext, parentContext)) + .compose(Tracing.withContext(parentContext))); + + return postProcessFunctionResult( + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); + }) + .compose(Tracing.withContext(parentContext)); } /** @@ -410,34 +411,27 @@ private static Maybe postProcessFunctionResult( }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]") - .setParent(parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } - }); + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event event = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + Tracing.traceToolResponse(event.id(), event); + return Maybe.just(event); + }); + }) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index e00c0093d..a93eb3cb4 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -29,6 +29,7 @@ import com.google.adk.events.Event; import com.google.adk.events.ToolConfirmation; import com.google.adk.models.LlmRequest; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -37,6 +38,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collection; @@ -216,10 +218,13 @@ private Maybe assembleEvent( .build()) .build(); - return toolsMapSingle.flatMapMaybe( - toolsMap -> - Functions.handleFunctionCalls( - invocationContext, functionCallEvent, toolsMap, toolConfirmations)); + Context parentContext = Context.current(); + return toolsMapSingle + .flatMapMaybe( + toolsMap -> + Functions.handleFunctionCalls( + invocationContext, functionCallEvent, toolsMap, toolConfirmations)) + .compose(Tracing.withContext(parentContext)); } private static Optional> maybeCreateToolConfirmationEntry( diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index e534da787..8d0366e9a 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,13 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -126,6 +128,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -136,11 +139,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e))) + .compose(Tracing.withContext(capturedContext)); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -149,7 +154,8 @@ public Completable close() { .doOnError( e -> logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + "[{}] Error during callback 'close'", plugin.getName(), e))) + .compose(Tracing.withContext(capturedContext)); } @Override @@ -227,7 +233,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -247,6 +253,7 @@ private Maybe runMaybeCallbacks( plugin.getName(), callbackName, e))) - .firstElement(); + .firstElement() + .compose(Tracing.withContext(capturedContext)); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 5859c4786..51e1b8f25 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,6 +52,7 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -375,20 +376,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -441,7 +447,8 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); + return runAsyncImpl(session, newMessage, runConfig, stateDelta) + .compose(Tracing.trace("invocation")); } /** @@ -461,6 +468,7 @@ protected Flowable runAsyncImpl( Preconditions.checkNotNull(session, "session cannot be null"); Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); + Context capturedContext = Context.current(); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -476,6 +484,7 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) + .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -500,7 +509,8 @@ protected Flowable runAsyncImpl( event, invocationId, runConfig, - rootAgent)); + rootAgent)) + .compose(Tracing.withContext(capturedContext)); }); }) .doOnError( @@ -508,8 +518,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -562,12 +571,14 @@ private Flowable runAgentWithFreshSession( .toFlowable()); // If beforeRunCallback returns content, emit it and skip agent + Context capturedContext = Context.current(); return beforeRunEvent .toFlowable() .switchIfEmpty(agentEvents) .concatWith( Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) - .concatWith(Completable.defer(() -> compactEvents(updatedSession))); + .concatWith(Completable.defer(() -> compactEvents(updatedSession))) + .compose(Tracing.withContext(capturedContext)); } private Completable compactEvents(Session session) { @@ -632,46 +643,9 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(this.findAgentToRun(session, rootAgent)); } - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); } /** @@ -682,19 +656,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** @@ -708,6 +688,49 @@ public Flowable runLive( return runLive(sessionKey.userId(), sessionKey.id(), liveRequestQueue, runConfig); } + /** + * Runs the agent in live mode, appending generated events to the session. + * + * @return stream of events from the agent. + */ + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return Flowable.defer( + () -> { + Context capturedContext = Context.current(); + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.withContext(capturedContext)); + }); + } + /** * Runs the agent asynchronously with a default user ID. * diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 594e47fd8..d30b76aa5 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -572,8 +572,13 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + + // The tool calls and responses are children of the first LLM call that produced the function + // call. + String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); + toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach( + s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 4a0b345c6..ca8386053 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -43,9 +43,13 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; import java.util.Map; import java.util.Optional; @@ -572,6 +576,71 @@ public Single> runAsync(Map args, ToolContex } } + @Test + public void run_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + RequestProcessor requestProcessor = + (ctx, request) -> { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + ResponseProcessor responseProcessor = + (ctx, response) -> { + return Single.just(ResponseProcessingResult.create(response, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + return Maybe.empty().subscribeOn(Schedulers.computation()); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + return Maybe.just(resp).subscribeOn(Schedulers.computation()); + }; + + Callbacks.OnModelErrorCallback onErrorCallback = + (ctx, req, err) -> { + return Maybe.just( + LlmResponse.builder().content(Content.fromParts(Part.fromText("error"))).build()) + .subscribeOn(Schedulers.computation()); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .onModelErrorCallback(onErrorCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow(ImmutableList.of(requestProcessor), ImmutableList.of(responseProcessor)); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + baseLlmFlow + .run(invocationContext) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(events).hasSize(1); + assertThat(events.get(0).content()).hasValue(content); + } + @Test public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { GenerateContentResponseUsageMetadata usageMetadata = @@ -588,7 +657,12 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { List events = baseLlmFlow - .postprocess(invocationContext, baseEvent, LlmRequest.builder().build(), llmResponse) + .postprocess( + invocationContext, + baseEvent, + LlmRequest.builder().build(), + llmResponse, + io.opentelemetry.context.Context.current()) .toList() .blockingGet(); diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4ae856fc7..3771143cf 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -37,8 +37,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -144,6 +148,87 @@ public void onUserMessageCallback_pluginOrderRespected() { inOrder.verify(plugin2).onUserMessageCallback(mockInvocationContext, content); } + @Test + public void contextPropagation_runMaybeCallbacks() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content expectedContent = Content.builder().build(); + when(plugin1.onUserMessageCallback(any(), any())) + .thenReturn(Maybe.just(expectedContent).subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Maybe resultMaybe; + try (Scope scope = testContext.makeCurrent()) { + resultMaybe = pluginManager.onUserMessageCallback(mockInvocationContext, content); + } + + // Assert downstream operators have the propagated context + resultMaybe + .doOnSuccess( + result -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(expectedContent); + + verify(plugin1).onUserMessageCallback(mockInvocationContext, content); + } + + @Test + public void contextPropagation_afterRunCallback() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.afterRunCallback(any())) + .thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.afterRunCallback(mockInvocationContext); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).afterRunCallback(mockInvocationContext); + } + + @Test + public void contextPropagation_close() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.close()).thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.close(); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).close(); + } + @Test public void afterRunCallback_allComplete() { when(plugin1.afterRunCallback(any())).thenReturn(Completable.complete()); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 8a0a84b08..2eb515fa2 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -57,6 +57,9 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Completable; @@ -977,6 +980,84 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } + @Test + public void runAsync_createsToolSpansWithCorrectParent() { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + + var unused = + runnerWithTool + .runAsync( + sessionWithTool.sessionKey(), + createContent("from user"), + RunConfig.builder().build()) + .toList() + .blockingGet(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + assertThat(llmSpans).hasSize(2); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + + @Test + public void runLive_createsToolSpansWithCorrectParent() throws Exception { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runnerWithTool + .runLive(sessionWithTool.sessionKey(), liveRequestQueue, RunConfig.builder().build()) + .test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + // In runLive, there is one call_llm span for the execution + assertThat(llmSpans).hasSize(1); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -1188,6 +1269,53 @@ public void close_closesPluginsAndCodeExecutors() { verify(plugin).close(); } + @Test + public void runAsync_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + runner + .runAsync("user", session.id(), createContent("test message")) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + } + + @Test + public void runLive_contextPropagation() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber; + try (Scope scope = testContext.makeCurrent()) { + testSubscriber = + runner + .runLive(session, liveRequestQueue, RunConfig.builder().build()) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test(); + } + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + } + @Test public void buildRunnerWithPlugins_success() { BasePlugin plugin1 = mockPlugin("test1");