diff --git a/README.md b/README.md index 07b1f8c34..eb0d53585 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Weaviate Java client Weaviate logo +# Weaviate Java client Weaviate logo [![Build Status](https://github.com/weaviate/java-client/actions/workflows/.github/workflows/test.yaml/badge.svg?branch=main)](https://github.com/weaviate/java-client/actions/workflows/.github/workflows/test.yaml) diff --git a/src/it/java/io/weaviate/integration/RbacITest.java b/src/it/java/io/weaviate/integration/RbacITest.java index 57d6ca7e7..55a20ece7 100644 --- a/src/it/java/io/weaviate/integration/RbacITest.java +++ b/src/it/java/io/weaviate/integration/RbacITest.java @@ -84,15 +84,6 @@ public void test_roles_Lifecycle() throws IOException { Permission.groups("my-group", GroupType.OIDC, GroupsPermission.Action.READ)); }); - requireAtLeast(Weaviate.Version.V132, () -> { - permissions.add( - Permission.aliases("ThingsAlias", myCollection, AliasesPermission.Action.CREATE)); - }); - requireAtLeast(Weaviate.Version.V133, () -> { - permissions.add( - Permission.groups("my-group", GroupType.OIDC, GroupsPermission.Action.READ)); - }); - // Act: create role client.roles.create(nsRole, permissions); diff --git a/src/main/java/io/weaviate/client6/v1/api/Authentication.java b/src/main/java/io/weaviate/client6/v1/api/Authentication.java index 1c8776c1c..eb6a83ec6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Authentication.java +++ b/src/main/java/io/weaviate/client6/v1/api/Authentication.java @@ -56,6 +56,26 @@ public static Authentication resourceOwnerPassword(String username, String passw }; } + /** + * Authenticate using Resource Owner Password Credentials authorization grant. + * + * @param clientSecret Client secret. + * @param username Resource owner username. + * @param password Resource owner password. + * @param scopes Client scopes. + * + * @return Authentication provider. + * @throws WeaviateOAuthException if an error occurred at any point of the token + * exchange process. + */ + public static Authentication resourceOwnerPasswordCredentials(String clientSecret, String username, String password, + List scopes) { + return transport -> { + OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access"); + return TokenProvider.resourceOwnerPasswordCredentials(oidc, clientSecret, username, password); + }; + } + /** * Authenticate using Client Credentials authorization grant. * diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index 33baabb63..c3c16623e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -9,6 +9,7 @@ import io.weaviate.client6.v1.internal.BuildInfo; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.TransportOptions; @@ -24,7 +25,8 @@ public record Config( Map headers, Authentication authentication, TrustManagerFactory trustManagerFactory, - Timeout timeout) { + Timeout timeout, + Proxy proxy) { public static Config of(Function> fn) { return fn.apply(new Custom()).build(); @@ -40,7 +42,8 @@ private Config(Builder builder) { builder.headers, builder.authentication, builder.trustManagerFactory, - builder.timeout); + builder.timeout, + builder.proxy); } RestTransportOptions restTransportOptions() { @@ -48,7 +51,7 @@ RestTransportOptions restTransportOptions() { } RestTransportOptions restTransportOptions(TokenProvider tokenProvider) { - return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout); + return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout, proxy); } GrpcChannelOptions grpcTransportOptions() { @@ -56,7 +59,7 @@ GrpcChannelOptions grpcTransportOptions() { } GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) { - return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout); + return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout, proxy); } private abstract static class Builder> implements ObjectBuilder { @@ -70,6 +73,7 @@ private abstract static class Builder> implements O protected TrustManagerFactory trustManagerFactory; protected Timeout timeout = new Timeout(); protected Map headers = new HashMap<>(); + protected Proxy proxy; /** * Set URL scheme. Subclasses may increase the visibility of this method to @@ -175,6 +179,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) { return (SelfT) this; } + /** + * Set proxy for all requests. + */ + @SuppressWarnings("unchecked") + public SelfT proxy(Proxy proxy) { + this.proxy = proxy; + return (SelfT) this; + } + /** * Weaviate will use the URL in this header to call Weaviate Embeddings * Service if an appropriate vectorizer is configured for collection. diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java index 63e14c3bf..6150ab4d5 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java @@ -63,14 +63,13 @@ public class WeaviateClient implements AutoCloseable { public final WeaviateClusterClient cluster; public WeaviateClient(Config config) { - RestTransportOptions restOpt; + RestTransportOptions restOpt = config.restTransportOptions(); GrpcChannelOptions grpcOpt; if (config.authentication() == null) { - restOpt = config.restTransportOptions(); grpcOpt = config.grpcTransportOptions(); } else { TokenProvider tokenProvider; - try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) { + try (final var noAuthRest = new DefaultRestTransport(restOpt)) { tokenProvider = config.authentication().getTokenProvider(noAuthRest); } catch (Exception e) { // Generally exceptions are caught in TokenProvider internals. @@ -126,6 +125,10 @@ public WeaviateClient(Config config) { this.config = config; } + public Config getConfig() { + return config; + } + /** * Create {@link WeaviateClientAsync} with identical configurations. * It is a shorthand for: diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java index c0fca9bbc..e6101f682 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java @@ -75,7 +75,7 @@ public CollectionHandle> use( return use(CollectionDescriptor.ofMap(collectionName), fn); } - private CollectionHandle use(CollectionDescriptor collection, + public CollectionHandle use(CollectionDescriptor collection, Function> fn) { return new CollectionHandle<>(restTransport, grpcTransport, collection, CollectionHandleDefaults.of(fn)); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index e5fb263d0..27755645e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -335,7 +335,21 @@ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { @Override public void close() throws IOException { boolean closedBefore = closed; - closed = true; + + // Update the value atomically to make sure shutdownNow + // does not unnecessarily interrupt this thread. + synchronized (this) { + closed = true; + } + + // If we'd been interrupted by shutdownNow, closing would've been + // completed exceptionally prior to that. If that's not the case + // but the current thread is interrupted, then we must propagate + // the interrupt. But first, we should dispose of the services. + if (Thread.interrupted() && !closing.isCompletedExceptionally()) { + shutdownExecutors(); + Thread.currentThread().interrupt(); + } log.atDebug() .addKeyValue("closed_before", closedBefore) @@ -409,16 +423,26 @@ private void shutdownNow(Exception e) { send.cancel(true); } - if (!closed) { - // Since shutdownNow is never triggered by the "main" thread, - // it may be blocked on trying to add to the queue. While batch - // context is active, we own this thread and may interrupt it. - log.atDebug() - .addKeyValue("thread", Thread::currentThread) - .addKeyValue("closed", closed) - .log("Interrupt parent thread"); - parent.interrupt(); + // Since shutdownNow is never triggered by the "main" thread, + // it may be blocked on trying to add to the queue. While batch + // context is active, we own this thread and may interrupt it. + // We must be able to guarantee that shutdownNow never interrupts + // an in-progress close and we also don't want to potentially block + // the gRPC thread on which shutdownNow may be executing; we use + // the doubly-checked locking pattern to helps us achieve that. + if (closed) { + return; } + synchronized (this) { + if (!closed) { + log.atDebug() + .addKeyValue("thread", Thread::currentThread) + .addKeyValue("closed", closed) + .log("Interrupt parent thread"); + parent.interrupt(); + } + } + } private void shutdownExecutors() { diff --git a/src/main/java/io/weaviate/client6/v1/internal/Proxy.java b/src/main/java/io/weaviate/client6/v1/internal/Proxy.java new file mode 100644 index 000000000..5637cfeae --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/Proxy.java @@ -0,0 +1,15 @@ +package io.weaviate.client6.v1.internal; + +import javax.annotation.Nullable; + +public record Proxy( + String scheme, + String host, + int port, + @Nullable String username, + @Nullable String password +) { + public Proxy(String host, int port) { + this("http", host, port, null, null); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java index 7b28a5e7a..e88f48563 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java @@ -141,6 +141,24 @@ public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String userna return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY)); } + /** + * Create a TokenProvider that uses Resource Owner Password Credentials authorization grant. + * + * @param oidc OIDC config. + * @param clientSecret Client secret. + * @param username Resource owner username. + * @param password Resource owner password. + * + * @return Internal TokenProvider implementation. + * @throws WeaviateOAuthException if an error occurred at any point of the token + * exchange process. + */ + public static TokenProvider resourceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username, + String password) { + final var passwordGrant = NimbusTokenProvider.resouceOwnerPasswordCredentials(oidc, clientSecret, username, password); + return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY)); + } + /** * Create a TokenProvider that uses Client Credentials authorization grant. * diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java index 06c0b6c15..8a5a1a789 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java @@ -11,9 +11,10 @@ public abstract class TransportOptions { protected final H headers; protected final TrustManagerFactory trustManagerFactory; protected final Timeout timeout; + protected final Proxy proxy; protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider, - TrustManagerFactory tmf, Timeout timeout) { + TrustManagerFactory tmf, Timeout timeout, Proxy proxy) { this.scheme = scheme; this.host = host; this.port = port; @@ -21,6 +22,7 @@ protected TransportOptions(String scheme, String host, int port, H headers, Toke this.headers = headers; this.timeout = timeout; this.trustManagerFactory = tmf; + this.proxy = proxy; } public boolean isSecure() { @@ -58,6 +60,11 @@ public TrustManagerFactory trustManagerFactory() { return this.trustManagerFactory; } + @Nullable + public Proxy proxy() { + return this.proxy; + } + /** * isWeaviateDomain returns true if the host matches weaviate.io, * semi.technology, or weaviate.cloud domain. @@ -73,4 +80,9 @@ public static boolean isGoogleCloudDomain(String host) { var lower = host.toLowerCase(); return lower.contains("gcp"); } + + @Nullable + public Proxy proxy() { + return this.proxy; + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index 385808ffc..f43c6f756 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -2,6 +2,8 @@ import static java.util.Objects.requireNonNull; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.OptionalInt; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -12,7 +14,7 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; - +import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.ManagedChannel; import io.grpc.StatusRuntimeException; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; @@ -22,12 +24,19 @@ import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.api.WeaviateApiException; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest; +import javax.net.ssl.SSLException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + public final class DefaultGrpcTransport implements GrpcTransport { /** * ListenableFuture callbacks are executed @@ -92,7 +101,7 @@ public CompletableFuture perf var method = rpc.methodAsync(); var stub = applyTimeout(futureStub, rpc); var reply = method.apply(stub, message); - return toCompletableFuture(reply).thenApply(r -> rpc.unmarshal(r)); + return toCompletableFuture(reply).thenApply(rpc::unmarshal); } /** @@ -146,6 +155,27 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) channel.sslContext(sslCtx); } + if (transportOptions.proxy() != null) { + Proxy proxy = transportOptions.proxy(); + if ("http".equals(proxy.scheme()) || "https".equals(proxy.scheme())) { + final SocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port()); + channel.proxyDetector(targetAddress -> { + if (targetAddress instanceof InetSocketAddress) { + HttpConnectProxiedSocketAddress.Builder builder = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress((InetSocketAddress) targetAddress); + + if (proxy.username() != null && proxy.password() != null) { + builder.setUsername(proxy.username()); + builder.setPassword(proxy.password()); + } + return builder.build(); + } + return null; + }); + } + } + channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); return channel.build(); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index 96366cb5f..dee4bd1e3 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -6,6 +6,7 @@ import javax.net.ssl.TrustManagerFactory; import io.grpc.Metadata; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.TransportOptions; @@ -14,20 +15,20 @@ public class GrpcChannelOptions extends TransportOptions { private final OptionalInt maxMessageSize; public GrpcChannelOptions(String scheme, String host, int port, Map headers, - TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) { - this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout); + TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout, Proxy proxy) { + this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout, proxy); } private GrpcChannelOptions(String scheme, String host, int port, Metadata headers, - TokenProvider tokenProvider, TrustManagerFactory tmf, OptionalInt maxMessageSize, Timeout timeout) { - super(scheme, host, port, headers, tokenProvider, tmf, timeout); + TokenProvider tokenProvider, TrustManagerFactory tmf, OptionalInt maxMessageSize, Timeout timeout, Proxy proxy) { + super(scheme, host, port, headers, tokenProvider, tmf, timeout, proxy); this.maxMessageSize = maxMessageSize; } public GrpcChannelOptions withMaxMessageSize(int maxMessageSize) { return new GrpcChannelOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, OptionalInt.of(maxMessageSize), - timeout); + timeout, proxy); } public OptionalInt maxMessageSize() { diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java index 858fd4d82..75e22a4a6 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.internal.oidc; +import io.weaviate.client6.v1.internal.Proxy; + import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -11,16 +13,26 @@ public record OidcConfig( String clientId, String providerMetadata, - Set scopes) { + Set scopes, + Proxy proxy) { - public OidcConfig(String clientId, String providerMetadata, Set scopes) { + public OidcConfig(String clientId, String providerMetadata, Set scopes, Proxy proxy) { this.clientId = clientId; this.providerMetadata = providerMetadata; this.scopes = scopes != null ? Set.copyOf(scopes) : Collections.emptySet(); + this.proxy = proxy; + } + + public OidcConfig(String clientId, String providerMetadata, Set scopes) { + this(clientId, providerMetadata, scopes, null); } public OidcConfig(String clientId, String providerMetadata, List scopes) { - this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes)); + this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), null); + } + + public OidcConfig(String clientId, String providerMetadata, List scopes, Proxy proxy) { + this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), proxy); } /** Create a new OIDC config with extended scopes. */ @@ -31,6 +43,6 @@ public OidcConfig withScopes(String... scopes) { /** Create a new OIDC config with extended scopes. */ public OidcConfig withScopes(List scopes) { var newScopes = Stream.concat(this.scopes.stream(), scopes.stream()).collect(Collectors.toSet()); - return new OidcConfig(clientId, providerMetadata, newScopes); + return new OidcConfig(clientId, providerMetadata, newScopes, proxy); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java index cafcc1289..027ea60c5 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java @@ -1,17 +1,16 @@ package io.weaviate.client6.v1.internal.oidc; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - import com.google.gson.annotations.SerializedName; - import io.weaviate.client6.v1.api.WeaviateOAuthException; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.ExternalEndpoint; import io.weaviate.client6.v1.internal.rest.RestTransport; import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; +import java.io.IOException; +import java.util.Collections; +import java.util.List; + public final class OidcUtils { /** Prevents public initialization. */ private OidcUtils() { @@ -28,7 +27,7 @@ private OidcUtils() { private static final Endpoint GET_PROVIDER_METADATA_ENDPOINT = new ExternalEndpoint<>( request -> "GET", request -> request, // URL is the request body. - requesf -> Collections.emptyMap(), + request -> Collections.emptyMap(), request -> null, (__, response) -> response); @@ -54,6 +53,6 @@ public static final OidcConfig getConfig(RestTransport transport) { throw new WeaviateOAuthException("fetch provider metadata", e); } - return new OidcConfig(openid.clientId(), providerMetadata, openid.scopes()); + return new OidcConfig(openid.clientId(), providerMetadata, openid.scopes(), transport.getProxy()); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java index ad12561ff..4d0c3c32f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java @@ -27,6 +27,21 @@ static Flow resourceOwnerPassword(String username, String password) { return () -> grant; // Reuse cached authorization grant } + static Flow resourceOwnerPasswordCredentials(String clientId, String clientSecret, String username, String password) { + return new Flow() { + private final AuthorizationGrant GRANT = new ResourceOwnerPasswordCredentialsGrant(username, new Secret(password)); + @Override + public AuthorizationGrant getAuthorizationGrant() { + return GRANT; + } + + @Override + public ClientAuthentication getClientAuthentication() { + return new ClientSecretPost(new ClientID(clientId), new Secret(clientSecret)); + } + }; + } + static Flow clientCredentials(String clientId, String clientSecret) { return new Flow() { private static final AuthorizationGrant GRANT = new ClientCredentialsGrant(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java index 4bef50e71..5c240e4f2 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.internal.oidc.nimbus; import java.io.IOException; +import java.net.Proxy; import javax.annotation.concurrent.NotThreadSafe; @@ -8,7 +9,7 @@ import com.nimbusds.oauth2.sdk.Scope; import com.nimbusds.oauth2.sdk.TokenRequest; import com.nimbusds.oauth2.sdk.id.ClientID; -import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import com.nimbusds.oauth2.sdk.as.AuthorizationServerMetadata; import com.nimbusds.openid.connect.sdk.token.OIDCTokens; import io.weaviate.client6.v1.api.WeaviateOAuthException; @@ -17,10 +18,11 @@ @NotThreadSafe public final class NimbusTokenProvider implements TokenProvider { - private final OIDCProviderMetadata metadata; + private final AuthorizationServerMetadata metadata; private final ClientID clientId; private final Scope scope; private final Flow flow; + private Proxy proxy; /** * Create a TokenProvider that uses Refresh Token authorization grant. @@ -51,6 +53,23 @@ public static NimbusTokenProvider resourceOwnerPassword(OidcConfig oidc, String return new NimbusTokenProvider(oidc, Flow.resourceOwnerPassword(username, password)); } + /** + * Create a TokenProvider that uses Resource Owner Password Credentials authorization grant. + * + * @param oidc OIDC config. + * @param clientSecret Client secret. + * @param username Resource owner username. + * @param password Resource owner password. + * + * @return A new instance of NimbusTokenProvider. Instances are never cached. + * @throws WeaviateOAuthException if an error occured at any point of the + * exchange process. + */ + public static NimbusTokenProvider resouceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username, + String password) { + return new NimbusTokenProvider(oidc, Flow.resourceOwnerPasswordCredentials(oidc.clientId(), clientSecret, username, password)); + } + /** * Create a TokenProvider that uses Client Credentials authorization grant. * @@ -70,6 +89,10 @@ private NimbusTokenProvider(OidcConfig oidc, Flow flow) { this.clientId = new ClientID(oidc.clientId()); this.scope = new Scope(oidc.scopes().toArray(String[]::new)); this.flow = flow; + var proxy = oidc.proxy(); + if (proxy != null) { + this.proxy = new java.net.Proxy(java.net.Proxy.Type.HTTP, new java.net.InetSocketAddress(proxy.host(), proxy.port())); + } } @Override @@ -83,6 +106,10 @@ public Token getToken() { : new TokenRequest(uri, clientAuth, grant, scope); var request = tokenRequest.toHTTPRequest(); + if (proxy != null) { + request.setProxy(proxy); + } + OIDCTokens tokens; try { var response = request.send(); @@ -116,9 +143,9 @@ public static ProviderMetadata parseProviderMetadata(String providerMetadata) { return new ProviderMetadata(metadata.getTokenEndpointURI()); } - private static OIDCProviderMetadata _parseProviderMetadata(String providerMetadata) { + private static AuthorizationServerMetadata _parseProviderMetadata(String providerMetadata) { try { - return OIDCProviderMetadata.parse(providerMetadata); + return AuthorizationServerMetadata.parse(providerMetadata); } catch (ParseException ex) { throw new WeaviateOAuthException("parse provider metadata: ", ex); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java index a77d9da29..b80880841 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -8,6 +8,7 @@ import javax.net.ssl.SSLContext; +import org.apache.hc.core5.http.HttpHost; import org.apache.hc.client5.http.async.methods.SimpleHttpRequest; import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; import org.apache.hc.client5.http.config.RequestConfig; @@ -31,6 +32,7 @@ import io.weaviate.client6.v1.api.WeaviateApiException; import io.weaviate.client6.v1.api.WeaviateTransportException; +import io.weaviate.client6.v1.internal.Proxy; public class DefaultRestTransport implements RestTransport { private final CloseableHttpClient httpClient; @@ -43,9 +45,9 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { this.transportOptions = transportOptions; // TODO: doesn't make sense to spin up both? - var httpClient = HttpClients.custom() + var httpClient = HttpClients.custom().useSystemProperties() .setDefaultHeaders(transportOptions.headers()); - var httpClientAsync = HttpAsyncClients.custom() + var httpClientAsync = HttpAsyncClients.custom().useSystemProperties() .setDefaultHeaders(transportOptions.headers()); // Apply custom SSL context @@ -68,6 +70,13 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { httpClientAsync.setConnectionManager(asyncManager); } + if (transportOptions.proxy() != null) { + Proxy proxy = transportOptions.proxy(); + HttpHost proxyHost = new HttpHost(proxy.scheme(), proxy.host(), proxy.port()); + httpClient.setProxy(proxyHost); + httpClientAsync.setProxy(proxyHost); + } + if (transportOptions.timeout() != null) { var config = RequestConfig.custom() .setResponseTimeout(transportOptions.timeout().querySeconds(), TimeUnit.SECONDS) @@ -111,8 +120,6 @@ private ClassicHttpRequest prepareClassicRequest(RequestT if (body != null) { req.setEntity(body, ContentType.APPLICATION_JSON); } - if (true) { - } return req.build(); } @@ -196,6 +203,10 @@ private ResponseT _handleResponse(Endpoint endpoint, S throw new WeaviateTransportException("Unhandled endpoint type " + endpoint.getClass().getSimpleName()); } + public Proxy getProxy() { + return transportOptions.proxy(); + } + @Override public void close() throws Exception { httpClient.close(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java index da26c9f12..62f20871c 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java @@ -3,6 +3,8 @@ import java.io.IOException; import java.util.concurrent.CompletableFuture; +import io.weaviate.client6.v1.internal.Proxy; + public interface RestTransport extends AutoCloseable { ResponseT performRequest(RequestT request, Endpoint endpoint) @@ -10,4 +12,6 @@ ResponseT performRequest(RequestT request, CompletableFuture performRequestAsync(RequestT request, Endpoint endpoint); + + Proxy getProxy(); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java index 5da4cdd5f..4635a9368 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java @@ -8,6 +8,7 @@ import org.apache.hc.core5.http.message.BasicHeader; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.TransportOptions; @@ -16,17 +17,21 @@ public final class RestTransportOptions extends TransportOptions headers, - TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout) { - super(scheme, host, port, buildHeaders(headers), tokenProvider, trust, timeout); + TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout, Proxy proxy) { + super(scheme, host, port, buildHeaders(headers), tokenProvider, trust, timeout, proxy); } private RestTransportOptions(String scheme, String host, int port, Collection headers, - TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout) { - super(scheme, host, port, headers, tokenProvider, trust, timeout); + TokenProvider tokenProvider, TrustManagerFactory trust, Timeout timeout, Proxy proxy) { + super(scheme, host, port, headers, tokenProvider, trust, timeout, proxy); + } + + public RestTransportOptions(String http, String localhost, Integer localPort, Map headers, Object tokenProvider, Object trust, Timeout timeout) { + super(http, localhost, localPort, buildHeaders(headers), (TokenProvider) tokenProvider, (TrustManagerFactory) trust, timeout, null); } public final RestTransportOptions withTimeout(Timeout timeout) { - return new RestTransportOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, timeout); + return new RestTransportOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory, timeout, proxy); } private static final Collection buildHeaders(Map headers) { diff --git a/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java b/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java index 54472875e..124789280 100644 --- a/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java @@ -60,6 +60,81 @@ public void testAuthentication_apiKey() throws Exception { .withHeader("Authorization", "Bearer my-api-key")); } + @Test + public void testAuthentication_resourceOwnerPasswordWithClientSecret() throws Exception { + // 1. Mock /.well-known/openid-configuration + mockServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/.well-known/openid-configuration") + ).respond( + org.mockserver.model.HttpResponse.response() + .withStatusCode(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"clientId\": \"my-client-id\", \"href\": \"http://localhost:" + mockServer.getLocalPort() + "/oidc-provider\"}") + ); + + // 2. Mock OIDC provider metadata + mockServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/oidc-provider") + ).respond( + org.mockserver.model.HttpResponse.response() + .withStatusCode(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"issuer\": \"http://localhost:" + mockServer.getLocalPort() + "\", \"token_endpoint\": \"http://localhost:" + mockServer.getLocalPort() + "/token\"}") + ); + + // 3. Mock Token Endpoint + mockServer.when( + HttpRequest.request() + .withMethod("POST") + .withPath("/token") + ).respond( + org.mockserver.model.HttpResponse.response() + .withStatusCode(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"access_token\": \"secret-token\", \"token_type\": \"Bearer\", \"expires_in\": 3600}") + ); + + var authz = Authentication.resourceOwnerPasswordCredentials("my-client-secret", "my-user", "my-pass", Collections.emptyList()); + var transportOptions = new RestTransportOptions( + "http", "localhost", mockServer.getLocalPort(), + Collections.emptyMap(), authz.getTokenProvider(noAuthTransport), null, new Timeout()); + + try (final var restClient = new DefaultRestTransport(transportOptions)) { + restClient.performRequest(null, SimpleEndpoint.sideEffect( + request -> "GET", request -> "/", request -> null)); + } catch (WeaviateApiException ex) { + if (ex.httpStatusCode() != 404) { + Assertions.fail("unexpected error", ex); + } + } + + // Verify token request had both password grant and client authentication + mockServer.verify( + HttpRequest.request() + .withMethod("POST") + .withPath("/token") + .withBody(org.mockserver.model.ParameterBody.params( + org.mockserver.model.Parameter.param("grant_type", "password"), + org.mockserver.model.Parameter.param("username", "my-user"), + org.mockserver.model.Parameter.param("password", "my-pass"), + org.mockserver.model.Parameter.param("client_id", "my-client-id"), + org.mockserver.model.Parameter.param("client_secret", "my-client-secret"), + org.mockserver.model.Parameter.param("scope", "offline_access") + )) + ); + + // Verify the actual request used the obtained token + mockServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/") + .withHeader("Authorization", "Bearer secret-token")); + } + @After public void stopMockServer() throws Exception { mockServer.stop(); diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/batch/BatchContextTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/batch/BatchContextTest.java index 80ab367ce..74703a0f2 100644 --- a/src/test/java/io/weaviate/client6/v1/api/collections/batch/BatchContextTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/collections/batch/BatchContextTest.java @@ -20,7 +20,11 @@ import org.assertj.core.api.Assertions; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -95,6 +99,19 @@ private StreamObserver createStream(StreamObserver recv) { return in; } + @Rule + public TestName currentTest = new TestName(); + + private boolean testFailed; + + @Rule + public TestWatcher __ = new TestWatcher() { + @Override + protected void failed(Throwable e, Description description) { + testFailed = true; + } + }; + /** * Create new unstarted context with default maxSizeBytes, collection * descriptor, and collection handle defaults. @@ -102,6 +119,7 @@ private StreamObserver createStream(StreamObserver recv) { @Before public void startContext() throws InterruptedException { log.debug("===================startContext=================="); + log.debug(currentTest.getMethodName()); assert !Thread.currentThread().isInterrupted() : "main thread interrupted"; assert REQUEST_QUEUE.isEmpty() : "stream contains incoming message " + REQUEST_QUEUE.peek(); @@ -117,12 +135,21 @@ public void startContext() throws InterruptedException { context.start(); in.expectMessage(START); - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); } @After public void reset() throws Exception { - if (!contextClosed) { + log.atDebug() + .addKeyValue("contextClosed", contextClosed) + .addKeyValue("testFailed", testFailed) + .log("Begin test cleanup"); + + // Do not attempt to close the context if it has been previously closed + // by the test or the test has failed. In the latter case closing the + // context may lead to a deadlock if the case hasn't scheduled Results + // for all submitted messages. + if (!contextClosed && !testFailed) { closeContext(); } @@ -133,11 +160,14 @@ public void reset() throws Exception { // This resets the interrupted flag, allowing use to await the executors. Thread.interrupted(); - backgroundThread.shutdownNow(); - backgroundThread.awaitTermination(10, TimeUnit.SECONDS); + try { + backgroundThread.shutdownNow(); + backgroundThread.awaitTermination(10, TimeUnit.SECONDS); - eventThread.shutdownNow(); - eventThread.awaitTermination(10, TimeUnit.SECONDS); + eventThread.shutdownNow(); + eventThread.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + } REQUEST_QUEUE.clear(); @@ -159,7 +189,7 @@ private void closeContext() throws Exception { try { context.close(); - eof.get(); + eof.get(5, TimeUnit.SECONDS); } finally { contextClosed = true; } @@ -175,17 +205,21 @@ public void test_sendOneBatch() throws Exception { // BatchContext should flush the current batch once it hits its limit. // We will ack all items in the batch and send successful result for each one. List received = recvDataAndAck(); + out.beforeEof(new Event.Results(received, Collections.emptyMap())); + Assertions.assertThat(tasks) .extracting(TaskHandle::id).containsExactlyInAnyOrderElementsOf(received); - Assertions.assertThat(tasks) - .extracting(TaskHandle::isAcked).allMatch(CompletableFuture::isDone); - out.beforeEof(new Event.Results(received, Collections.emptyMap())); + CompletableFuture[] tasksAcked = tasks.stream() + .map(TaskHandle::isAcked).toArray(CompletableFuture[]::new); + Assertions.assertThat(CompletableFuture.allOf(tasksAcked)) + .succeedsWithin(5, TimeUnit.SECONDS); // Since MockServer runs in the same thread as this test, // the context will be updated before the last emitEvent returns. closeContext(); + // By the time context.close() returns all tasks MUST have results set. Assertions.assertThat(tasks).extracting(TaskHandle::result) .allMatch(CompletableFuture::isDone) .extracting(CompletableFuture::get).extracting(TaskHandle.Result::error) @@ -207,8 +241,11 @@ public void test_drainOnClose() throws Exception { List received = recvDataAndAck(); Assertions.assertThat(tasks).extracting(TaskHandle::id) .containsExactlyInAnyOrderElementsOf(received); - Assertions.assertThat(tasks).extracting(TaskHandle::isAcked) - .allMatch(CompletableFuture::isDone); + + CompletableFuture[] tasksAcked = tasks.stream() + .map(TaskHandle::isAcked).toArray(CompletableFuture[]::new); + Assertions.assertThat(CompletableFuture.allOf(tasksAcked)) + .succeedsWithin(5, TimeUnit.SECONDS); } catch (Exception e) { throw new RuntimeException(e); } @@ -273,6 +310,7 @@ public void test_backoffBacklog() throws Exception { int batchSizeNew = BATCH_SIZE / 2; // Force the last BATCH_SIZE / 2 - 1 items to be transferred to the backlog. + // Await for this event to be processed before moving forward. out.emitEvent(new Event.Backoff(batchSizeNew)); // The next item will go on the backlog and the trigger a flush, @@ -298,14 +336,14 @@ public void test_backoffBacklog() throws Exception { @Test public void test_reconnect_onShutdown() throws Exception { - out.emitEvent(Event.SHUTTING_DOWN); + out.emitEventAsync(Event.SHUTTING_DOWN); in.expectMessage(STOP); out.eof(true); in.expectMessage(START); // Not strictly necessary -- we can close the context // before a new connection is established. - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); } @Test @@ -319,18 +357,18 @@ public void test_reconnect_onOom() throws Exception { // Respond with OOM and wait for the client to close its end of the stream. in.expectMessage(DATA); - out.emitEvent(new Event.Oom(0)); + out.emitEventAsync(new Event.Oom(0)); // Close the server's end of the stream. in.expectMessage(STOP); // Allow the client to reconnect to another "instance" and Ack the batch. in.expectMessage(START); - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); recvDataAndAck(); List submitted = tasks.stream().map(TaskHandle::id).toList(); - out.emitEvent(new Event.Results(submitted, Collections.emptyMap())); + out.emitEventAsync(new Event.Results(submitted, Collections.emptyMap())); } @Test @@ -347,7 +385,7 @@ public void test_reconnect_onStreamHangup() throws Exception { // The client should try to reconnect, because the context is still open. in.expectMessage(START); - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); // The previous batch hasn't been acked, so we should expect to receive it // again. @@ -358,7 +396,7 @@ public void test_reconnect_onStreamHangup() throws Exception { // in the queue to wake the sender up. out.hangup(); in.expectMessage(START); - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); tasks.add(context.add(WeaviateObject.of())); recvDataAndAck(); @@ -399,7 +437,7 @@ public void test_reconnect_DrainAfterStreamHangup() throws Exception { // When the server starts accepting connections again, the client should // drain the remaining BATCH_SIZE+1 objects as we close the context. in.expectMessage(START); - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); Future backgroundAcks = backgroundThread.submit(() -> { try { recvDataAndAck(); @@ -426,7 +464,7 @@ public void test_reconnect_DrainAfterStreamHangup() throws Exception { public void test_closeAfterStreamHangup() throws Exception { out.hangup(); in.expectMessage(START); - out.emitEvent(Event.STARTED); + out.emitEventAsync(Event.STARTED); } @Test @@ -464,7 +502,7 @@ public void test_startAfterClose() throws Exception { */ private List recvDataAndAck() throws InterruptedException { List received = recvData(); - out.emitEvent(new Event.Acks(received)); + out.emitEventAsync(new Event.Acks(received)); return received; } @@ -492,7 +530,16 @@ private static final class OutboundStream { this.eventThread = eventThread; } - CompletableFuture emitEvent(Event event) { + /** Emit event on the current thread. */ + void emitEvent(Event event) { + assert event != Event.EOF : "must not use synthetic EOF event"; + assert !(event instanceof Event.StreamHangup) : "must not use synthetic StreamHangup event"; + + stream.onNext(event); + } + + /** Emit event on the {@link #eventThread}. */ + CompletableFuture emitEventAsync(Event event) { assert event != Event.EOF : "must not use synthetic EOF event"; assert !(event instanceof Event.StreamHangup) : "must not use synthetic StreamHangup event"; @@ -522,7 +569,7 @@ CompletableFuture eof(boolean ok) { if (ok) { // These are guaranteed to finish before onCompleted, // as eventThread is just 1 thread. - pendingEvents.forEach(this::emitEvent); + pendingEvents.forEach(this::emitEventAsync); } return CompletableFuture.runAsync(stream::onCompleted, eventThread); } diff --git a/src/test/java/io/weaviate/client6/v1/internal/rest/ProxyTest.java b/src/test/java/io/weaviate/client6/v1/internal/rest/ProxyTest.java new file mode 100644 index 000000000..2a61e51e8 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/internal/rest/ProxyTest.java @@ -0,0 +1,142 @@ +package io.weaviate.client6.v1.internal.rest; + +import io.weaviate.client6.v1.api.WeaviateApiException; +import io.weaviate.client6.v1.api.Config; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.internal.Proxy; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpRequest; +import org.mockserver.model.HttpResponse; + +import java.io.IOException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockserver.model.HttpForward.forward; + +public class ProxyTest { + private ClientAndServer targetServer; + private ClientAndServer proxyServer; + private WeaviateClient client; + + @Before + public void setUp() { + targetServer = ClientAndServer.startClientAndServer(0); + proxyServer = ClientAndServer.startClientAndServer(0); + + // Set up target server to return a success response + targetServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/.well-known/live")) + .respond( + HttpResponse.response() + .withStatusCode(200)); + + targetServer.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")) + .respond( + HttpResponse.response() + .withStatusCode(200) + .withBody("{\"version\": \"1.32.0\"}")); + + // Set up proxy server to forward requests to the target server + proxyServer.when( + HttpRequest.request()) + .forward( + forward() + .withHost("localhost") + .withPort(targetServer.getLocalPort()) + .withScheme(org.mockserver.model.HttpForward.Scheme.HTTP)); + + Config config = Config.of(c -> c + .scheme("http") + .httpHost("localhost") + .httpPort(targetServer.getLocalPort()) + .grpcHost("localhost") + .grpcPort(targetServer.getLocalPort()) + .proxy(new Proxy("localhost", proxyServer.getLocalPort())) + .timeout(5) + ); + + client = new WeaviateClient(config); + } + + @Test + public void testClientInitializationWithProxy() { + // This test verifies that the client can be successfully created. + // The WeaviateClient constructor performs REST calls to /v1/.well-known/live + // and /v1/meta to verify the connection and version support. + // If these calls fail, the constructor throws a WeaviateConnectException. + // Since setUp() already creates a client using the proxy, we just need to + // verify it was initialized correctly. + assertThat(client).isNotNull(); + + // Verify that the initialization calls went through the proxy + proxyServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/.well-known/live")); + proxyServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")); + } + + @Test + public void testRestProxy() throws IOException { + // Perform a request that should go through the proxy + client.meta(); + + // Verify that the proxy server received the request + proxyServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")); + + // Verify that the target server also received the request (forwarded by proxy) + targetServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/meta")); + } + + @Test + public void testProxyConfiguration() { + // In this test, we verify that the client has the proxy configured. + assertThat(client.getConfig().proxy()).isNotNull(); + assertThat(client.getConfig().proxy().port()).isEqualTo((long) proxyServer.getLocalPort()); + } + + @Test + public void testGrpcProxy() { + // gRPC proxying via HTTP CONNECT. + // DefaultGrpcTransport uses a custom ProxyDetector which returns a + // HttpConnectProxiedSocketAddress when a proxy is configured. + + // To verify that gRPC proxying is correctly set up, we check the configuration. + // Since actual CONNECT verification via MockServer is tricky in this setup, + // we focus on ensuring the client is correctly initialized with the proxy. + assertThatThrownBy(() -> client.collections.use("Test").size()) + .isInstanceOf(WeaviateApiException.class) + .hasMessageContaining("UNAVAILABLE: Network closed"); + } + + @After + public void tearDown() throws Exception { + if (client != null) { + client.close(); + } + if (proxyServer != null) { + proxyServer.stop(); + } + if (targetServer != null) { + targetServer.stop(); + } + } +} diff --git a/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java b/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java index 587cb2548..7991078cd 100644 --- a/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java +++ b/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java @@ -6,6 +6,7 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; +import io.weaviate.client6.v1.internal.Proxy; import io.weaviate.client6.v1.internal.rest.BooleanEndpoint; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.RestTransport; @@ -25,7 +26,7 @@ public interface AssertFunction { void apply(String method, String requestUrl, String body, Map queryParameters); } - private List> requests = new ArrayList<>(); + private final List> requests = new ArrayList<>(); public void assertNext(AssertFunction... assertions) { var assertN = Math.min(assertions.length, requests.size()); @@ -60,4 +61,9 @@ public CompletableFuture performReq @Override public void close() throws IOException { } + + @Override + public Proxy getProxy() { + return null; + } }