diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java index 2825c834ffc5b0..ab83428b0c4d7b 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java @@ -361,6 +361,22 @@ private static void addSslRemoteConfig( final String sslEngineProviderName = CustomSSLEngineProvider.class.getCanonicalName(); + LOG.debug( + "Creating RPC SSL configuration with enabled={}, protocol={}, " + + "enabledAlgorithms={}, keyStoreConfigured={}, keyStoreType={}, " + + "trustStoreConfigured={}, trustStoreType={}, certFingerprintsConfigured={}, " + + "requireMutualAuthentication={}, sslEngineProvider={}", + enableSSLConfig, + sslProtocol, + sslAlgorithmsString, + sslKeyStore != null, + sslKeyStoreType, + sslTrustStore != null, + sslTrustStoreType, + sslCertFingerprintString != null && !sslCertFingerprintString.isEmpty(), + true, + sslEngineProviderName); + configBuilder .add("pekko {") .add(" remote.classic {") diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java index c89dd5eef7d91a..dd359a24fb4320 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java @@ -35,6 +35,9 @@ import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslProvider; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.util.FingerprintTrustManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import javax.annotation.Nullable; import javax.net.ServerSocketFactory; import javax.net.SocketFactory; @@ -42,6 +45,8 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLServerSocket; import javax.net.ssl.SSLServerSocketFactory; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManagerFactory; import java.io.File; @@ -49,6 +54,7 @@ import java.io.InputStream; import java.net.InetAddress; import java.net.ServerSocket; +import java.net.Socket; import java.nio.file.Files; import java.security.KeyStore; import java.security.KeyStoreException; @@ -67,6 +73,8 @@ /** Common utilities to manage SSL transport settings. */ public class SSLUtils { + private static final Logger LOG = LoggerFactory.getLogger(SSLUtils.class); + /** * Creates a factory for SSL Server Sockets from the given configuration. SSL Server Sockets are * always part of internal communication. @@ -96,7 +104,11 @@ public static SocketFactory createSSLClientSocketFactory(Configuration config) throw new IllegalConfigurationException("SSL is not enabled"); } - return sslContext.getSocketFactory(); + String[] protocols = getEnabledProtocols(config); + String[] cipherSuites = getEnabledCipherSuites(config); + + SSLSocketFactory factory = sslContext.getSocketFactory(); + return new ConfiguringSSLSocketFactory(factory, protocols, cipherSuites); } /** Creates a SSLEngineFactory to be used by internal communication server endpoints. */ @@ -357,6 +369,17 @@ private static SslContext createInternalNettySSLContext( int sessionCacheSize = config.get(SecurityOptions.SSL_INTERNAL_SESSION_CACHE_SIZE); int sessionTimeoutMs = config.get(SecurityOptions.SSL_INTERNAL_SESSION_TIMEOUT); + LOG.debug( + "Creating internal SSL context with provider={}, clientMode={}, clientAuth={}, " + + "protocols={}, cipherSuites={}, sessionCacheSize={}, sessionTimeoutMs={}", + provider, + clientMode, + ClientAuth.REQUIRE, + Arrays.toString(sslProtocols), + ciphers, + sessionCacheSize, + sessionTimeoutMs); + KeyManagerFactory kmf = getKeyManagerFactory(config, true, provider); ClientAuth clientAuth = ClientAuth.REQUIRE; @@ -421,6 +444,16 @@ public static SslContext createRestNettySSLContext( String[] sslProtocols = getEnabledProtocols(config); List ciphers = Arrays.asList(getEnabledCipherSuites(config)); + LOG.debug( + "Creating REST SSL context with provider={}, clientMode={}, clientAuth={}, " + + "protocols={}, cipherSuites={}, verifyHostname={}", + provider, + clientMode, + clientAuth, + Arrays.toString(sslProtocols), + ciphers, + clientMode ? config.get(SecurityOptions.SSL_REST_VERIFY_HOSTNAME) : null); + final SslContextBuilder sslContextBuilder; if (clientMode) { sslContextBuilder = SslContextBuilder.forClient(); @@ -433,22 +466,21 @@ public static SslContext createRestNettySSLContext( } else { KeyManagerFactory kmf = getKeyManagerFactory(config, false, provider); sslContextBuilder = SslContextBuilder.forServer(kmf); + if (clientAuth != ClientAuth.NONE) { + sslContextBuilder.clientAuth(clientAuth); + } } if (clientMode || clientAuth != ClientAuth.NONE) { Optional tmf = getTrustManagerFactory(config, false); - tmf.map( - // Use specific ciphers and protocols if SSL is configured with self-signed - // certificates (user-supplied truststore) - tm -> - sslContextBuilder - .trustManager(tm) - .protocols(sslProtocols) - .ciphers(ciphers) - .clientAuth(clientAuth)); - } - - return sslContextBuilder.sslProvider(provider).build(); + tmf.ifPresent(sslContextBuilder::trustManager); + } + + return sslContextBuilder + .sslProvider(provider) + .protocols(sslProtocols) + .ciphers(ciphers) + .build(); } // ------------------------------------------------------------------------ @@ -476,6 +508,72 @@ private static String getAndCheckOption( // Wrappers for socket factories that additionally configure the sockets // ------------------------------------------------------------------------ + private static class ConfiguringSSLSocketFactory extends SSLSocketFactory { + + private final SSLSocketFactory sslSocketFactory; + private final String[] protocols; + private final String[] cipherSuites; + + ConfiguringSSLSocketFactory( + SSLSocketFactory sslSocketFactory, String[] protocols, String[] cipherSuites) { + this.sslSocketFactory = sslSocketFactory; + this.protocols = protocols; + this.cipherSuites = cipherSuites; + } + + @Override + public String[] getDefaultCipherSuites() { + return sslSocketFactory.getDefaultCipherSuites(); + } + + @Override + public String[] getSupportedCipherSuites() { + return sslSocketFactory.getSupportedCipherSuites(); + } + + @Override + public Socket createSocket() throws IOException { + return configureSocket(sslSocketFactory.createSocket()); + } + + @Override + public Socket createSocket(Socket socket, String host, int port, boolean autoClose) + throws IOException { + return configureSocket(sslSocketFactory.createSocket(socket, host, port, autoClose)); + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + return configureSocket(sslSocketFactory.createSocket(host, port)); + } + + @Override + public Socket createSocket(String host, int port, InetAddress localHost, int localPort) + throws IOException { + return configureSocket(sslSocketFactory.createSocket(host, port, localHost, localPort)); + } + + @Override + public Socket createSocket(InetAddress host, int port) throws IOException { + return configureSocket(sslSocketFactory.createSocket(host, port)); + } + + @Override + public Socket createSocket( + InetAddress address, int port, InetAddress localAddress, int localPort) + throws IOException { + return configureSocket( + sslSocketFactory.createSocket(address, port, localAddress, localPort)); + } + + private Socket configureSocket(Socket socket) { + SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.setEnabledProtocols(protocols); + sslSocket.setEnabledCipherSuites(cipherSuites); + return sslSocket; + } + } + private static class ConfiguringSSLServerSocketFactory extends ServerSocketFactory { private final SSLServerSocketFactory sslServerSocketFactory; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java index 43e7de101ee9d9..0e2915b95c35a4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java @@ -33,11 +33,14 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocket; import java.io.File; import java.io.InputStream; import java.net.ServerSocket; +import java.net.Socket; import java.nio.file.Files; import java.security.KeyStore; import java.security.KeyStoreException; @@ -175,6 +178,27 @@ void testRESTSSLConfigCipherAlgorithms(String sslProvider) throws Exception { assertThat(cipherSuites).containsExactlyInAnyOrder(testSSLAlgorithms.split(",")); } + @Test + void testRestServerAppliesConfiguredProtocolsAndCipherSuites() throws Exception { + final Configuration config = createRestSslConfigWithKeyStore("JDK"); + config.set(SecurityOptions.SSL_PROTOCOL, "TLSv1.2"); + config.set( + SecurityOptions.SSL_ALGORITHMS, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + + final JdkSslContext nettySSLContext = + (JdkSslContext) + SSLUtils.createRestNettySSLContext(config, false, ClientAuth.NONE, JDK); + final SSLEngine sslEngine = + checkNotNull(nettySSLContext).newEngine(UnpooledByteBufAllocator.DEFAULT); + + assertThat(sslEngine.getEnabledProtocols()).containsExactly("TLSv1.2"); + assertThat(sslEngine.getEnabledCipherSuites()) + .containsExactlyInAnyOrder( + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + } + // ------------------------ server -------------------------- /** Tests that REST Server SSL Engine is created given a valid SSL configuration. */ @@ -387,6 +411,29 @@ void testSetSSLVersionAndCipherSuitesForSSLServerSocket(String sslProvider) thro } } + @ParameterizedTest + @MethodSource("parameters") + void testSetSSLVersionAndCipherSuitesForSSLClientSocket(String sslProvider) throws Exception { + final Configuration clientConfig = + createInternalSslConfigWithKeyAndTrustStores(sslProvider); + + clientConfig.set(SecurityOptions.SSL_PROTOCOL, "TLSv1.1"); + clientConfig.set( + SecurityOptions.SSL_ALGORITHMS, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + + try (Socket socket = SSLUtils.createSSLClientSocketFactory(clientConfig).createSocket()) { + assertThat(socket).isInstanceOf(SSLSocket.class); + final SSLSocket sslSocket = (SSLSocket) socket; + + assertThat(sslSocket.getEnabledProtocols()).containsExactly("TLSv1.1"); + assertThat(sslSocket.getEnabledCipherSuites()) + .containsExactlyInAnyOrder( + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + } + } + /** Tests that {@link SSLHandlerFactory} is created correctly. */ @ParameterizedTest @MethodSource("parameters")