diff --git a/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/org/a2aproject/sdk/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java b/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/org/a2aproject/sdk/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java index 61676ff00..a35b58ad1 100644 --- a/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/org/a2aproject/sdk/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java +++ b/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/org/a2aproject/sdk/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java @@ -397,11 +397,11 @@ public void testMultiInstanceEventReplication() throws Exception { getClient1().subscribeToTask(new TaskIdParams(taskId), List.of(app1Subscriber), app1ErrorHandler); getClient2().subscribeToTask(new TaskIdParams(taskId), List.of(app2Subscriber), app2ErrorHandler); - // Wait for subscriptions to be established - at least one event should arrive on each + // Wait for subscriptions to be established - initial TaskEvent should arrive on each await() .atMost(Duration.ofSeconds(10)) .pollInterval(Duration.ofMillis(500)) - .until(() -> app1EventCount.get() >= 1 && app2EventCount.get() >= 1); + .until(() -> app1ReceivedInitialTask.get() && app2ReceivedInitialTask.get()); // Step 3: Send message on app1 (should generate TaskArtifactUpdateEvent) int app1BeforeMsg1 = app1EventCount.get(); diff --git a/http-client/src/main/java/org/a2aproject/sdk/client/http/JdkA2AHttpClient.java b/http-client/src/main/java/org/a2aproject/sdk/client/http/JdkA2AHttpClient.java index 47c7f0e32..dc0887ddb 100644 --- a/http-client/src/main/java/org/a2aproject/sdk/client/http/JdkA2AHttpClient.java +++ b/http-client/src/main/java/org/a2aproject/sdk/client/http/JdkA2AHttpClient.java @@ -176,16 +176,15 @@ public void onNext(String item) { @Override public void onError(Throwable throwable) { if (errorNotified.compareAndSet(false, true)) { - errorConsumer.accept(throwable); - } - if (subscription != null) { - subscription.cancel(); + if (!(throwable instanceof java.util.concurrent.CancellationException)) { + errorConsumer.accept(throwable); + } } } @Override public void onComplete() { - if (!errorNotified.get()) { + if (errorNotified.compareAndSet(false, true)) { if (useSseParser.get()) { sseParser.flush(); } else { @@ -196,9 +195,6 @@ public void onComplete() { } completeRunnable.run(); } - if (subscription != null) { - subscription.cancel(); - } } }; @@ -251,7 +247,9 @@ public void onComplete() { .handle((response, throwable) -> { if (throwable != null && errorNotified.compareAndSet(false, true)) { Throwable cause = throwable.getCause() != null ? throwable.getCause() : throwable; - errorConsumer.accept(cause); + if (!(cause instanceof java.util.concurrent.CancellationException)) { + errorConsumer.accept(cause); + } } return null; }); diff --git a/http-client/src/test/java/org/a2aproject/sdk/client/http/JdkA2AHttpClientTest.java b/http-client/src/test/java/org/a2aproject/sdk/client/http/JdkA2AHttpClientTest.java index a9c0c62d0..63bd29a8a 100644 --- a/http-client/src/test/java/org/a2aproject/sdk/client/http/JdkA2AHttpClientTest.java +++ b/http-client/src/test/java/org/a2aproject/sdk/client/http/JdkA2AHttpClientTest.java @@ -6,16 +6,39 @@ import org.mockserver.integration.ClientAndServer; import java.io.IOException; +import java.net.Authenticator; +import java.net.CookieHandler; import java.net.Proxy; import java.net.ProxySelector; import java.net.SocketAddress; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.WebSocket; +import java.time.Duration; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockserver.model.HttpRequest.request; import static org.mockserver.model.HttpResponse.response; @@ -74,6 +97,120 @@ public void testConstructorRejectsNullHttpClient() { assertThrows(IllegalArgumentException.class, () -> new JdkA2AHttpClient(null), "foo"); } + @Test + public void testCancellationExceptionViaSubscriberOnErrorIsNotPropagated() throws Exception { + AtomicReference capturedError = new AtomicReference<>(); + AtomicBoolean completed = new AtomicBoolean(false); + CountDownLatch errorPathReached = new CountDownLatch(1); + + HttpClient fakeClient = new StubHttpClient() { + @Override + public CompletableFuture> sendAsync( + HttpRequest request, HttpResponse.BodyHandler handler) { + HttpResponse.BodySubscriber bodySubscriber = + handler.apply(new FakeResponseInfo(200, "text/plain")); + bodySubscriber.onSubscribe(new NoOpSubscription()); + bodySubscriber.onError(new CancellationException()); + errorPathReached.countDown(); + return new CompletableFuture<>(); // never completes + } + }; + + new JdkA2AHttpClient(fakeClient) + .createGet() + .url("http://example.com/sse") + .getAsyncSSE(event -> {}, e -> capturedError.set(e), () -> completed.set(true)); + + assertTrue(errorPathReached.await(5, TimeUnit.SECONDS)); + assertNull(capturedError.get(), "CancellationException should not reach the error consumer"); + assertFalse(completed.get(), "Complete handler must not be called after cancellation"); + } + + @Test + public void testCancellationExceptionViaFutureFailureIsNotPropagated() throws Exception { + AtomicReference capturedError = new AtomicReference<>(); + AtomicBoolean completed = new AtomicBoolean(false); + + HttpClient fakeClient = new StubHttpClient() { + @Override + public CompletableFuture> sendAsync( + HttpRequest request, HttpResponse.BodyHandler handler) { + CompletableFuture> future = new CompletableFuture<>(); + future.completeExceptionally(new CancellationException()); + return future; + } + }; + + CompletableFuture result = new JdkA2AHttpClient(fakeClient) + .createGet() + .url("http://example.com/sse") + .getAsyncSSE(event -> {}, e -> capturedError.set(e), () -> completed.set(true)); + + result.get(5, TimeUnit.SECONDS); + assertNull(capturedError.get(), "CancellationException should not reach the error consumer"); + assertFalse(completed.get(), "Complete handler must not be called after cancellation"); + } + + @Test + public void testRealErrorsAreStillPropagatedToErrorConsumer() throws Exception { + AtomicReference capturedError = new AtomicReference<>(); + CountDownLatch errorLatch = new CountDownLatch(1); + IOException expectedError = new IOException("connection refused"); + + HttpClient fakeClient = new StubHttpClient() { + @Override + public CompletableFuture> sendAsync( + HttpRequest request, HttpResponse.BodyHandler handler) { + CompletableFuture> future = new CompletableFuture<>(); + future.completeExceptionally(expectedError); + return future; + } + }; + + new JdkA2AHttpClient(fakeClient) + .createGet() + .url("http://example.com/sse") + .getAsyncSSE(event -> {}, e -> { capturedError.set(e); errorLatch.countDown(); }, () -> {}); + + assertTrue(errorLatch.await(5, TimeUnit.SECONDS)); + assertNotNull(capturedError.get()); + assertEquals(expectedError, capturedError.get()); + } + + private abstract static class StubHttpClient extends HttpClient { + @Override public Optional cookieHandler() { return Optional.empty(); } + @Override public Optional connectTimeout() { return Optional.empty(); } + @Override public HttpClient.Redirect followRedirects() { return HttpClient.Redirect.NORMAL; } + @Override public Optional proxy() { return Optional.empty(); } + @Override public SSLContext sslContext() { throw new UnsupportedOperationException(); } + @Override public SSLParameters sslParameters() { return new SSLParameters(); } + @Override public Optional authenticator() { return Optional.empty(); } + @Override public HttpClient.Version version() { return HttpClient.Version.HTTP_1_1; } + @Override public Optional executor() { return Optional.empty(); } + @Override public HttpResponse send(HttpRequest r, HttpResponse.BodyHandler h) { throw new UnsupportedOperationException(); } + @Override public CompletableFuture> sendAsync(HttpRequest r, HttpResponse.BodyHandler h, HttpResponse.PushPromiseHandler p) { return sendAsync(r, h); } + @Override public WebSocket.Builder newWebSocketBuilder() { throw new UnsupportedOperationException(); } + } + + private static final class FakeResponseInfo implements HttpResponse.ResponseInfo { + private final int statusCode; + private final HttpHeaders headers; + + FakeResponseInfo(int statusCode, String contentType) { + this.statusCode = statusCode; + this.headers = HttpHeaders.of(Map.of("Content-Type", List.of(contentType)), (k, v) -> true); + } + + @Override public int statusCode() { return statusCode; } + @Override public HttpHeaders headers() { return headers; } + @Override public HttpClient.Version version() { return HttpClient.Version.HTTP_1_1; } + } + + private static final class NoOpSubscription implements Flow.Subscription { + @Override public void request(long n) {} + @Override public void cancel() {} + } + private static final class TrackingProxySelector extends ProxySelector { private final AtomicInteger selectCount = new AtomicInteger(); diff --git a/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java b/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java index aaf90f42e..74ac5acf6 100644 --- a/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java +++ b/reference/multiversion-jsonrpc/src/main/java/org/a2aproject/sdk/server/multiversion/jsonrpc/MultiVersionJSONRPCRoutes.java @@ -38,7 +38,10 @@ void setupRoutes(@Observes Router router) { .handler(BodyHandler.create()) .blockingHandler(ctx -> { try { - vertxSecurityHelper.runInRequestContext(ctx, () -> { + // JSON-RPC multiplexes streaming and non-streaming methods over a single + // endpoint, so we always use deferred CDI context destruction — matching + // the single-version A2AServerRoutes behavior. + vertxSecurityHelper.runInRequestContextDeferred(ctx, () -> { String version = VersionRouter.resolveVersion(ctx); String body = ctx.body().asString(); @@ -53,7 +56,8 @@ void setupRoutes(@Observes Router router) { null); } }); - } catch (UnauthorizedException | ForbiddenException e) {vertxSecurityHelper.handleAuthError(ctx, e); + } catch (UnauthorizedException | ForbiddenException e) { + vertxSecurityHelper.handleAuthError(ctx, e); } catch (A2AError e) { ctx.response() .setStatusCode(200) diff --git a/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java b/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java index 05a39a480..816f2639a 100644 --- a/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java +++ b/reference/multiversion-rest/src/main/java/org/a2aproject/sdk/server/multiversion/rest/MultiVersionRestRoutes.java @@ -42,16 +42,16 @@ void setupRoutes(@Observes @Priority(5) Router router) { router.postWithRegex("^\\/v1\\/message:send$") .order(-1) .handler(BodyHandler.create()) - .blockingHandler(versionDispatch( + .blockingHandler(versionDispatch(false, MultiVersionRestRoutes::bridgeTenant, (body, ctx) -> v10Routes.sendMessage(body, ctx), (body, ctx) -> v03Routes.sendMessage(body, ctx)), false); - // POST /v1/message:stream + // POST /v1/message:stream (deferred CDI context destruction) router.postWithRegex("^\\/v1\\/message:stream$") .order(-1) .handler(BodyHandler.create()) - .blockingHandler(versionDispatch( + .blockingHandler(versionDispatch(true, MultiVersionRestRoutes::bridgeTenant, (body, ctx) -> v10Routes.sendMessageStreaming(body, ctx), (body, ctx) -> v03Routes.sendMessageStreaming(body, ctx)), false); @@ -59,7 +59,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { // GET /v1/tasks/{taskId} router.getWithRegex("^\\/v1\\/tasks\\/(?[^:^/]+)$") .order(-1) - .blockingHandler(versionDispatchNoBody( + .blockingHandler(versionDispatchNoBody(false, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.getTask(ctx), ctx -> v03Routes.getTask(ctx)), false); @@ -68,15 +68,15 @@ void setupRoutes(@Observes @Priority(5) Router router) { router.postWithRegex("^\\/v1\\/tasks\\/(?[^/]+):cancel$") .order(-1) .handler(BodyHandler.create()) - .blockingHandler(versionDispatch( + .blockingHandler(versionDispatch(false, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, (body, ctx) -> v10Routes.cancelTask(body, ctx), (body, ctx) -> v03Routes.cancelTask(ctx)), false); - // POST /v1/tasks/{taskId}:subscribe + // POST /v1/tasks/{taskId}:subscribe (deferred CDI context destruction) router.postWithRegex("^\\/v1\\/tasks\\/(?[^/]+):subscribe$") .order(-1) - .blockingHandler(versionDispatchNoBody( + .blockingHandler(versionDispatchNoBody(true, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.subscribeToTask(ctx), ctx -> v03Routes.resubscribeTask(ctx)), false); @@ -85,7 +85,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { router.postWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs$") .order(-1) .handler(BodyHandler.create()) - .blockingHandler(versionDispatch( + .blockingHandler(versionDispatch(false, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, (body, ctx) -> v10Routes.createTaskPushNotificationConfiguration(body, ctx), (body, ctx) -> v03Routes.setTaskPushNotificationConfiguration(body, ctx)), false); @@ -93,7 +93,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { // GET /v1/tasks/{taskId}/pushNotificationConfigs/{configId} router.getWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs\\/(?[^\\/]+)") .order(-1) - .blockingHandler(versionDispatchNoBody( + .blockingHandler(versionDispatchNoBody(false, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.getTaskPushNotificationConfiguration(ctx), ctx -> v03Routes.getTaskPushNotificationConfiguration(ctx)), false); @@ -101,7 +101,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { // GET /v1/tasks/{taskId}/pushNotificationConfigs router.getWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs\\/?$") .order(-1) - .blockingHandler(versionDispatchNoBody( + .blockingHandler(versionDispatchNoBody(false, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.listTaskPushNotificationConfigurations(ctx), ctx -> v03Routes.listTaskPushNotificationConfigurations(ctx)), false); @@ -109,7 +109,7 @@ void setupRoutes(@Observes @Priority(5) Router router) { // DELETE /v1/tasks/{taskId}/pushNotificationConfigs/{configId} router.deleteWithRegex("^\\/v1\\/tasks\\/(?[^/]+)\\/pushNotificationConfigs\\/(?[^/]+)") .order(-1) - .blockingHandler(versionDispatchNoBody( + .blockingHandler(versionDispatchNoBody(false, ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); }, ctx -> v10Routes.deleteTaskPushNotificationConfiguration(ctx), ctx -> v03Routes.deleteTaskPushNotificationConfiguration(ctx)), false); @@ -142,13 +142,17 @@ private static void bridgeTaskId(RoutingContext ctx) { } } + /** + * @param deferContextDestruction if true, defers CDI request context destruction until the SSE response completes + */ private io.vertx.core.Handler versionDispatch( + boolean deferContextDestruction, Consumer paramBridger, BiConsumer v10Handler, BiConsumer v03Handler) { return ctx -> { try { - vertxSecurityHelper.runInRequestContext(ctx, () -> { + Runnable task = () -> { String version = VersionRouter.resolveVersion(ctx); paramBridger.accept(ctx); String body = ctx.body().asString(); @@ -165,7 +169,8 @@ private io.vertx.core.Handler versionDispatch( "Protocol version '" + version + "' is not supported. Supported versions: [1.0, 0.3]", null); } - }); + }; + runInContext(ctx, task, deferContextDestruction); } catch (UnauthorizedException | ForbiddenException e) { vertxSecurityHelper.handleAuthError(ctx, e); } catch (A2AError e) { @@ -176,13 +181,17 @@ private io.vertx.core.Handler versionDispatch( }; } + /** + * @param deferContextDestruction if true, defers CDI request context destruction until the SSE response completes + */ private io.vertx.core.Handler versionDispatchNoBody( + boolean deferContextDestruction, Consumer paramBridger, Consumer v10Handler, Consumer v03Handler) { return ctx -> { try { - vertxSecurityHelper.runInRequestContext(ctx, () -> { + Runnable task = () -> { String version = VersionRouter.resolveVersion(ctx); paramBridger.accept(ctx); if (VersionRouter.isV10(version)) { @@ -195,7 +204,8 @@ private io.vertx.core.Handler versionDispatchNoBody( "Protocol version '" + version + "' is not supported. Supported versions: [1.0, 0.3]", null); } - }); + }; + runInContext(ctx, task, deferContextDestruction); } catch (UnauthorizedException | ForbiddenException e) { vertxSecurityHelper.handleAuthError(ctx, e); } catch (A2AError e) { @@ -206,6 +216,14 @@ private io.vertx.core.Handler versionDispatchNoBody( }; } + private void runInContext(RoutingContext ctx, Runnable task, boolean deferContextDestruction) { + if (deferContextDestruction) { + vertxSecurityHelper.runInRequestContextDeferred(ctx, task); + } else { + vertxSecurityHelper.runInRequestContext(ctx, task); + } + } + private static void sendA2AErrorResponse(RoutingContext ctx, A2AError error) { A2AErrorCodes errorCode = A2AErrorCodes.fromCode(error.getCode()); int httpStatus = errorCode != null ? errorCode.httpCode() : 400;