Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,11 @@ public void testMultiInstanceEventReplication() throws Exception {
getClient1().subscribeToTask(new TaskIdParams(taskId), List.of(app1Subscriber), app1ErrorHandler);
getClient2().subscribeToTask(new TaskIdParams(taskId), List.of(app2Subscriber), app2ErrorHandler);

// Wait for subscriptions to be established - at least one event should arrive on each
// Wait for subscriptions to be established - initial TaskEvent should arrive on each
await()
.atMost(Duration.ofSeconds(10))
.pollInterval(Duration.ofMillis(500))
.until(() -> app1EventCount.get() >= 1 && app2EventCount.get() >= 1);
.until(() -> app1ReceivedInitialTask.get() && app2ReceivedInitialTask.get());

// Step 3: Send message on app1 (should generate TaskArtifactUpdateEvent)
int app1BeforeMsg1 = app1EventCount.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,15 @@ public void onNext(String item) {
@Override
public void onError(Throwable throwable) {
if (errorNotified.compareAndSet(false, true)) {
errorConsumer.accept(throwable);
}
if (subscription != null) {
subscription.cancel();
if (!(throwable instanceof java.util.concurrent.CancellationException)) {
errorConsumer.accept(throwable);
}
}
}

@Override
public void onComplete() {
if (!errorNotified.get()) {
if (errorNotified.compareAndSet(false, true)) {
if (useSseParser.get()) {
sseParser.flush();
} else {
Expand All @@ -196,9 +195,6 @@ public void onComplete() {
}
completeRunnable.run();
}
if (subscription != null) {
subscription.cancel();
}
}
};

Expand Down Expand Up @@ -251,7 +247,9 @@ public void onComplete() {
.<Void>handle((response, throwable) -> {
if (throwable != null && errorNotified.compareAndSet(false, true)) {
Throwable cause = throwable.getCause() != null ? throwable.getCause() : throwable;
errorConsumer.accept(cause);
if (!(cause instanceof java.util.concurrent.CancellationException)) {
errorConsumer.accept(cause);
}
}
return null;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,39 @@
import org.mockserver.integration.ClientAndServer;

import java.io.IOException;
import java.net.Authenticator;
import java.net.CookieHandler;
import java.net.Proxy;
import java.net.ProxySelector;
import java.net.SocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.WebSocket;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockserver.model.HttpRequest.request;
import static org.mockserver.model.HttpResponse.response;

Expand Down Expand Up @@ -74,6 +97,120 @@ public void testConstructorRejectsNullHttpClient() {
assertThrows(IllegalArgumentException.class, () -> new JdkA2AHttpClient(null), "foo");
}

@Test
public void testCancellationExceptionViaSubscriberOnErrorIsNotPropagated() throws Exception {
AtomicReference<Throwable> capturedError = new AtomicReference<>();
AtomicBoolean completed = new AtomicBoolean(false);
CountDownLatch errorPathReached = new CountDownLatch(1);

HttpClient fakeClient = new StubHttpClient() {
@Override
public <T> CompletableFuture<HttpResponse<T>> sendAsync(
HttpRequest request, HttpResponse.BodyHandler<T> handler) {
HttpResponse.BodySubscriber<T> bodySubscriber =
handler.apply(new FakeResponseInfo(200, "text/plain"));
bodySubscriber.onSubscribe(new NoOpSubscription());
bodySubscriber.onError(new CancellationException());
errorPathReached.countDown();
return new CompletableFuture<>(); // never completes
}
};

new JdkA2AHttpClient(fakeClient)
.createGet()
.url("http://example.com/sse")
.getAsyncSSE(event -> {}, e -> capturedError.set(e), () -> completed.set(true));

assertTrue(errorPathReached.await(5, TimeUnit.SECONDS));
assertNull(capturedError.get(), "CancellationException should not reach the error consumer");
assertFalse(completed.get(), "Complete handler must not be called after cancellation");
}

@Test
public void testCancellationExceptionViaFutureFailureIsNotPropagated() throws Exception {
AtomicReference<Throwable> capturedError = new AtomicReference<>();
AtomicBoolean completed = new AtomicBoolean(false);

HttpClient fakeClient = new StubHttpClient() {
@Override
public <T> CompletableFuture<HttpResponse<T>> sendAsync(
HttpRequest request, HttpResponse.BodyHandler<T> handler) {
CompletableFuture<HttpResponse<T>> future = new CompletableFuture<>();
future.completeExceptionally(new CancellationException());
return future;
}
};

CompletableFuture<Void> result = new JdkA2AHttpClient(fakeClient)
.createGet()
.url("http://example.com/sse")
.getAsyncSSE(event -> {}, e -> capturedError.set(e), () -> completed.set(true));

result.get(5, TimeUnit.SECONDS);
assertNull(capturedError.get(), "CancellationException should not reach the error consumer");
assertFalse(completed.get(), "Complete handler must not be called after cancellation");
}

@Test
public void testRealErrorsAreStillPropagatedToErrorConsumer() throws Exception {
AtomicReference<Throwable> capturedError = new AtomicReference<>();
CountDownLatch errorLatch = new CountDownLatch(1);
IOException expectedError = new IOException("connection refused");

HttpClient fakeClient = new StubHttpClient() {
@Override
public <T> CompletableFuture<HttpResponse<T>> sendAsync(
HttpRequest request, HttpResponse.BodyHandler<T> handler) {
CompletableFuture<HttpResponse<T>> future = new CompletableFuture<>();
future.completeExceptionally(expectedError);
return future;
}
};

new JdkA2AHttpClient(fakeClient)
.createGet()
.url("http://example.com/sse")
.getAsyncSSE(event -> {}, e -> { capturedError.set(e); errorLatch.countDown(); }, () -> {});

assertTrue(errorLatch.await(5, TimeUnit.SECONDS));
assertNotNull(capturedError.get());
assertEquals(expectedError, capturedError.get());
}

private abstract static class StubHttpClient extends HttpClient {
@Override public Optional<CookieHandler> cookieHandler() { return Optional.empty(); }
@Override public Optional<Duration> connectTimeout() { return Optional.empty(); }
@Override public HttpClient.Redirect followRedirects() { return HttpClient.Redirect.NORMAL; }
@Override public Optional<ProxySelector> proxy() { return Optional.empty(); }
@Override public SSLContext sslContext() { throw new UnsupportedOperationException(); }
@Override public SSLParameters sslParameters() { return new SSLParameters(); }
@Override public Optional<Authenticator> authenticator() { return Optional.empty(); }
@Override public HttpClient.Version version() { return HttpClient.Version.HTTP_1_1; }
@Override public Optional<Executor> executor() { return Optional.empty(); }
@Override public <T> HttpResponse<T> send(HttpRequest r, HttpResponse.BodyHandler<T> h) { throw new UnsupportedOperationException(); }
@Override public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest r, HttpResponse.BodyHandler<T> h, HttpResponse.PushPromiseHandler<T> p) { return sendAsync(r, h); }
@Override public WebSocket.Builder newWebSocketBuilder() { throw new UnsupportedOperationException(); }
}

private static final class FakeResponseInfo implements HttpResponse.ResponseInfo {
private final int statusCode;
private final HttpHeaders headers;

FakeResponseInfo(int statusCode, String contentType) {
this.statusCode = statusCode;
this.headers = HttpHeaders.of(Map.of("Content-Type", List.of(contentType)), (k, v) -> true);
}

@Override public int statusCode() { return statusCode; }
@Override public HttpHeaders headers() { return headers; }
@Override public HttpClient.Version version() { return HttpClient.Version.HTTP_1_1; }
}

private static final class NoOpSubscription implements Flow.Subscription {
@Override public void request(long n) {}
@Override public void cancel() {}
}

private static final class TrackingProxySelector extends ProxySelector {
private final AtomicInteger selectCount = new AtomicInteger();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ void setupRoutes(@Observes Router router) {
.handler(BodyHandler.create())
.blockingHandler(ctx -> {
try {
vertxSecurityHelper.runInRequestContext(ctx, () -> {
// JSON-RPC multiplexes streaming and non-streaming methods over a single
// endpoint, so we always use deferred CDI context destruction — matching
// the single-version A2AServerRoutes behavior.
vertxSecurityHelper.runInRequestContextDeferred(ctx, () -> {
String version = VersionRouter.resolveVersion(ctx);
String body = ctx.body().asString();

Expand All @@ -53,7 +56,8 @@ void setupRoutes(@Observes Router router) {
null);
}
});
} catch (UnauthorizedException | ForbiddenException e) {vertxSecurityHelper.handleAuthError(ctx, e);
} catch (UnauthorizedException | ForbiddenException e) {
vertxSecurityHelper.handleAuthError(ctx, e);
} catch (A2AError e) {
ctx.response()
.setStatusCode(200)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,24 @@ void setupRoutes(@Observes @Priority(5) Router router) {
router.postWithRegex("^\\/v1\\/message:send$")
.order(-1)
.handler(BodyHandler.create())
.blockingHandler(versionDispatch(
.blockingHandler(versionDispatch(false,
MultiVersionRestRoutes::bridgeTenant,
(body, ctx) -> v10Routes.sendMessage(body, ctx),
(body, ctx) -> v03Routes.sendMessage(body, ctx)), false);

// POST /v1/message:stream
// POST /v1/message:stream (deferred CDI context destruction)
router.postWithRegex("^\\/v1\\/message:stream$")
.order(-1)
.handler(BodyHandler.create())
.blockingHandler(versionDispatch(
.blockingHandler(versionDispatch(true,
MultiVersionRestRoutes::bridgeTenant,
(body, ctx) -> v10Routes.sendMessageStreaming(body, ctx),
(body, ctx) -> v03Routes.sendMessageStreaming(body, ctx)), false);

// GET /v1/tasks/{taskId}
router.getWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^:^/]+)$")
.order(-1)
.blockingHandler(versionDispatchNoBody(
.blockingHandler(versionDispatchNoBody(false,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
ctx -> v10Routes.getTask(ctx),
ctx -> v03Routes.getTask(ctx)), false);
Expand All @@ -68,15 +68,15 @@ void setupRoutes(@Observes @Priority(5) Router router) {
router.postWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^/]+):cancel$")
.order(-1)
.handler(BodyHandler.create())
.blockingHandler(versionDispatch(
.blockingHandler(versionDispatch(false,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
(body, ctx) -> v10Routes.cancelTask(body, ctx),
(body, ctx) -> v03Routes.cancelTask(ctx)), false);

// POST /v1/tasks/{taskId}:subscribe
// POST /v1/tasks/{taskId}:subscribe (deferred CDI context destruction)
router.postWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^/]+):subscribe$")
.order(-1)
.blockingHandler(versionDispatchNoBody(
.blockingHandler(versionDispatchNoBody(true,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
ctx -> v10Routes.subscribeToTask(ctx),
ctx -> v03Routes.resubscribeTask(ctx)), false);
Expand All @@ -85,31 +85,31 @@ void setupRoutes(@Observes @Priority(5) Router router) {
router.postWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^/]+)\\/pushNotificationConfigs$")
.order(-1)
.handler(BodyHandler.create())
.blockingHandler(versionDispatch(
.blockingHandler(versionDispatch(false,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
(body, ctx) -> v10Routes.createTaskPushNotificationConfiguration(body, ctx),
(body, ctx) -> v03Routes.setTaskPushNotificationConfiguration(body, ctx)), false);

// GET /v1/tasks/{taskId}/pushNotificationConfigs/{configId}
router.getWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^/]+)\\/pushNotificationConfigs\\/(?<configId>[^\\/]+)")
.order(-1)
.blockingHandler(versionDispatchNoBody(
.blockingHandler(versionDispatchNoBody(false,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
ctx -> v10Routes.getTaskPushNotificationConfiguration(ctx),
ctx -> v03Routes.getTaskPushNotificationConfiguration(ctx)), false);

// GET /v1/tasks/{taskId}/pushNotificationConfigs
router.getWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^/]+)\\/pushNotificationConfigs\\/?$")
.order(-1)
.blockingHandler(versionDispatchNoBody(
.blockingHandler(versionDispatchNoBody(false,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
ctx -> v10Routes.listTaskPushNotificationConfigurations(ctx),
ctx -> v03Routes.listTaskPushNotificationConfigurations(ctx)), false);

// DELETE /v1/tasks/{taskId}/pushNotificationConfigs/{configId}
router.deleteWithRegex("^\\/v1\\/tasks\\/(?<taskId>[^/]+)\\/pushNotificationConfigs\\/(?<configId>[^/]+)")
.order(-1)
.blockingHandler(versionDispatchNoBody(
.blockingHandler(versionDispatchNoBody(false,
ctx -> { bridgeTenant(ctx); bridgeTaskId(ctx); },
ctx -> v10Routes.deleteTaskPushNotificationConfiguration(ctx),
ctx -> v03Routes.deleteTaskPushNotificationConfiguration(ctx)), false);
Expand Down Expand Up @@ -142,13 +142,17 @@ private static void bridgeTaskId(RoutingContext ctx) {
}
}

/**
* @param deferContextDestruction if true, defers CDI request context destruction until the SSE response completes
*/
private io.vertx.core.Handler<RoutingContext> versionDispatch(
boolean deferContextDestruction,
Consumer<RoutingContext> paramBridger,
BiConsumer<String, RoutingContext> v10Handler,
BiConsumer<String, RoutingContext> v03Handler) {
return ctx -> {
try {
vertxSecurityHelper.runInRequestContext(ctx, () -> {
Runnable task = () -> {
String version = VersionRouter.resolveVersion(ctx);
paramBridger.accept(ctx);
String body = ctx.body().asString();
Expand All @@ -165,7 +169,8 @@ private io.vertx.core.Handler<RoutingContext> versionDispatch(
"Protocol version '" + version + "' is not supported. Supported versions: [1.0, 0.3]",
null);
}
});
};
runInContext(ctx, task, deferContextDestruction);
} catch (UnauthorizedException | ForbiddenException e) {
vertxSecurityHelper.handleAuthError(ctx, e);
} catch (A2AError e) {
Expand All @@ -176,13 +181,17 @@ private io.vertx.core.Handler<RoutingContext> versionDispatch(
};
}

/**
* @param deferContextDestruction if true, defers CDI request context destruction until the SSE response completes
*/
private io.vertx.core.Handler<RoutingContext> versionDispatchNoBody(
boolean deferContextDestruction,
Consumer<RoutingContext> paramBridger,
Consumer<RoutingContext> v10Handler,
Consumer<RoutingContext> v03Handler) {
return ctx -> {
try {
vertxSecurityHelper.runInRequestContext(ctx, () -> {
Runnable task = () -> {
String version = VersionRouter.resolveVersion(ctx);
paramBridger.accept(ctx);
if (VersionRouter.isV10(version)) {
Expand All @@ -195,7 +204,8 @@ private io.vertx.core.Handler<RoutingContext> versionDispatchNoBody(
"Protocol version '" + version + "' is not supported. Supported versions: [1.0, 0.3]",
null);
}
});
};
runInContext(ctx, task, deferContextDestruction);
} catch (UnauthorizedException | ForbiddenException e) {
vertxSecurityHelper.handleAuthError(ctx, e);
} catch (A2AError e) {
Expand All @@ -206,6 +216,14 @@ private io.vertx.core.Handler<RoutingContext> versionDispatchNoBody(
};
}

private void runInContext(RoutingContext ctx, Runnable task, boolean deferContextDestruction) {
if (deferContextDestruction) {
vertxSecurityHelper.runInRequestContextDeferred(ctx, task);
} else {
vertxSecurityHelper.runInRequestContext(ctx, task);
}
}

private static void sendA2AErrorResponse(RoutingContext ctx, A2AError error) {
A2AErrorCodes errorCode = A2AErrorCodes.fromCode(error.getCode());
int httpStatus = errorCode != null ? errorCode.httpCode() : 400;
Expand Down
Loading