Skip to content

Commit 2c508a2

Browse files
committed
Try to connect to IPv4 and IPv6 simultaneously
Implement Happy Eyeballs (RFC 8305) connection algorithm to deal better with IPv4 and IPv6 simultaneously. You can still select IPv4 or IPv6 only via API and connecting directly to a single IPv4 or IPv6 address bypasses this algorithm. Additionally providing a proxy will bypass Happy Eyeballs as well.
1 parent 5ca3321 commit 2c508a2

3 files changed

Lines changed: 569 additions & 36 deletions

File tree

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
package com.trilead.ssh2.transport;
2+
3+
import java.io.IOException;
4+
import java.net.Inet6Address;
5+
import java.net.InetAddress;
6+
import java.net.InetSocketAddress;
7+
import java.net.Socket;
8+
import java.net.UnknownHostException;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.concurrent.Callable;
12+
import java.util.concurrent.CancellationException;
13+
import java.util.concurrent.ExecutionException;
14+
import java.util.concurrent.ExecutorService;
15+
import java.util.concurrent.Executors;
16+
import java.util.concurrent.Future;
17+
import java.util.concurrent.atomic.AtomicBoolean;
18+
19+
import com.trilead.ssh2.IpVersion;
20+
21+
/**
22+
* Implements Happy Eyeballs (RFC 8305) connection algorithm.
23+
*
24+
* This algorithm improves connection times when both IPv4 and IPv6
25+
* addresses are available by:
26+
* <ol>
27+
* <li>Resolving all addresses (A and AAAA records)</li>
28+
* <li>Starting IPv6 connection attempts first</li>
29+
* <li>After a short delay, starting IPv4 attempts in parallel</li>
30+
* <li>Using whichever connection succeeds first</li>
31+
* <li>Cancelling/closing remaining attempts</li>
32+
* </ol>
33+
*/
34+
class HappyEyeballsConnector {
35+
36+
static final int CONNECTION_ATTEMPT_DELAY_MS = 250;
37+
38+
private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(r -> {
39+
Thread t = new Thread(r, "HappyEyeballs-Connector");
40+
t.setDaemon(true);
41+
return t;
42+
});
43+
44+
@FunctionalInterface
45+
interface DnsResolver {
46+
InetAddress[] resolve(String hostname) throws UnknownHostException;
47+
}
48+
49+
@FunctionalInterface
50+
interface SocketFactory {
51+
Socket createSocket();
52+
}
53+
54+
private final DnsResolver dnsResolver;
55+
private final SocketFactory socketFactory;
56+
private final int connectionAttemptDelayMs;
57+
58+
HappyEyeballsConnector() {
59+
this(InetAddress::getAllByName, Socket::new, CONNECTION_ATTEMPT_DELAY_MS);
60+
}
61+
62+
HappyEyeballsConnector(DnsResolver dnsResolver, SocketFactory socketFactory, int connectionAttemptDelayMs) {
63+
this.dnsResolver = dnsResolver;
64+
this.socketFactory = socketFactory;
65+
this.connectionAttemptDelayMs = connectionAttemptDelayMs;
66+
}
67+
68+
/**
69+
* Connect to the given hostname and port using Happy Eyeballs algorithm.
70+
*
71+
* @param hostname the hostname to connect to
72+
* @param port the port to connect to
73+
* @param connectTimeout the connection timeout in milliseconds (0 for infinite)
74+
* @param ipVersion controls which IP versions to use
75+
* @return a connected socket
76+
* @throws IOException if connection fails
77+
*/
78+
Socket connect(String hostname, int port, int connectTimeout, IpVersion ipVersion)
79+
throws IOException {
80+
81+
List<InetAddress> addresses = resolveAddresses(hostname, ipVersion);
82+
83+
if (addresses.isEmpty()) {
84+
throw new UnknownHostException("No addresses found for: " + hostname);
85+
}
86+
87+
if (addresses.size() == 1) {
88+
return connectSimple(addresses.get(0), port, connectTimeout);
89+
}
90+
91+
List<InetAddress> sortedAddresses = interleaveByFamily(addresses);
92+
return connectWithRacing(sortedAddresses, port, connectTimeout);
93+
}
94+
95+
private List<InetAddress> resolveAddresses(String hostname, IpVersion ipVersion)
96+
throws UnknownHostException {
97+
InetAddress[] allAddresses = dnsResolver.resolve(hostname);
98+
return filterByIpVersion(allAddresses, ipVersion);
99+
}
100+
101+
static List<InetAddress> filterByIpVersion(InetAddress[] addresses, IpVersion ipVersion) {
102+
List<InetAddress> filtered = new ArrayList<>();
103+
104+
for (InetAddress addr : addresses) {
105+
boolean isIPv6 = addr instanceof Inet6Address;
106+
107+
if (ipVersion == IpVersion.IPV4_ONLY && isIPv6) {
108+
continue;
109+
}
110+
if (ipVersion == IpVersion.IPV6_ONLY && !isIPv6) {
111+
continue;
112+
}
113+
filtered.add(addr);
114+
}
115+
116+
return filtered;
117+
}
118+
119+
static List<InetAddress> interleaveByFamily(List<InetAddress> addresses) {
120+
List<InetAddress> ipv6 = new ArrayList<>();
121+
List<InetAddress> ipv4 = new ArrayList<>();
122+
123+
for (InetAddress addr : addresses) {
124+
if (addr instanceof Inet6Address) {
125+
ipv6.add(addr);
126+
} else {
127+
ipv4.add(addr);
128+
}
129+
}
130+
131+
List<InetAddress> result = new ArrayList<>();
132+
int maxSize = Math.max(ipv6.size(), ipv4.size());
133+
134+
for (int i = 0; i < maxSize; i++) {
135+
if (i < ipv6.size())
136+
result.add(ipv6.get(i));
137+
if (i < ipv4.size())
138+
result.add(ipv4.get(i));
139+
}
140+
141+
return result;
142+
}
143+
144+
private Socket connectWithRacing(List<InetAddress> addresses, int port, int connectTimeout)
145+
throws IOException {
146+
147+
AtomicBoolean winnerFound = new AtomicBoolean(false);
148+
List<Future<Socket>> futures = new ArrayList<>();
149+
List<Socket> socketsToClose = new ArrayList<>();
150+
151+
try {
152+
for (int i = 0; i < addresses.size(); i++) {
153+
InetAddress address = addresses.get(i);
154+
int delay = i * connectionAttemptDelayMs;
155+
156+
Callable<Socket> task = createConnectionTask(
157+
address, port, connectTimeout, delay, winnerFound, socketsToClose);
158+
futures.add(EXECUTOR.submit(task));
159+
}
160+
161+
return waitForFirstSuccess(futures);
162+
163+
} finally {
164+
for (Future<Socket> future : futures) {
165+
future.cancel(true);
166+
}
167+
168+
synchronized (socketsToClose) {
169+
for (Socket socket : socketsToClose) {
170+
closeQuietly(socket);
171+
}
172+
}
173+
}
174+
}
175+
176+
private Callable<Socket> createConnectionTask(
177+
InetAddress address,
178+
int port,
179+
int connectTimeout,
180+
int delay,
181+
AtomicBoolean winnerFound,
182+
List<Socket> socketsToClose) {
183+
184+
return () -> {
185+
if (delay > 0) {
186+
Thread.sleep(delay);
187+
}
188+
189+
if (winnerFound.get()) {
190+
throw new CancellationException("Another connection won");
191+
}
192+
193+
Socket socket = socketFactory.createSocket();
194+
synchronized (socketsToClose) {
195+
socketsToClose.add(socket);
196+
}
197+
198+
try {
199+
socket.connect(new InetSocketAddress(address, port), connectTimeout);
200+
socket.setSoTimeout(0);
201+
202+
if (winnerFound.compareAndSet(false, true)) {
203+
synchronized (socketsToClose) {
204+
socketsToClose.remove(socket);
205+
}
206+
return socket;
207+
} else {
208+
closeQuietly(socket);
209+
throw new CancellationException("Another connection won");
210+
}
211+
} catch (IOException e) {
212+
closeQuietly(socket);
213+
synchronized (socketsToClose) {
214+
socketsToClose.remove(socket);
215+
}
216+
throw e;
217+
}
218+
};
219+
}
220+
221+
private Socket waitForFirstSuccess(List<Future<Socket>> futures) throws IOException {
222+
IOException lastException = null;
223+
List<Future<Socket>> pending = new ArrayList<>(futures);
224+
225+
while (!pending.isEmpty()) {
226+
Future<Socket> completed = null;
227+
228+
for (Future<Socket> future : pending) {
229+
if (future.isDone()) {
230+
completed = future;
231+
break;
232+
}
233+
}
234+
235+
if (completed == null) {
236+
try {
237+
Thread.sleep(10);
238+
} catch (InterruptedException e) {
239+
Thread.currentThread().interrupt();
240+
throw new IOException("Connection interrupted", e);
241+
}
242+
continue;
243+
}
244+
245+
pending.remove(completed);
246+
247+
try {
248+
Socket socket = completed.get();
249+
if (socket != null && socket.isConnected()) {
250+
return socket;
251+
}
252+
} catch (CancellationException e) {
253+
// Task was cancelled, try next
254+
} catch (ExecutionException e) {
255+
Throwable cause = e.getCause();
256+
if (cause instanceof IOException) {
257+
lastException = (IOException) cause;
258+
} else if (cause instanceof InterruptedException) {
259+
Thread.currentThread().interrupt();
260+
throw new IOException("Connection interrupted", cause);
261+
} else {
262+
lastException = new IOException("Connection failed", cause);
263+
}
264+
} catch (InterruptedException e) {
265+
Thread.currentThread().interrupt();
266+
throw new IOException("Connection interrupted", e);
267+
}
268+
}
269+
270+
if (lastException != null) {
271+
throw lastException;
272+
}
273+
throw new IOException("All connection attempts failed");
274+
}
275+
276+
private Socket connectSimple(InetAddress address, int port, int timeout) throws IOException {
277+
Socket socket = socketFactory.createSocket();
278+
try {
279+
socket.connect(new InetSocketAddress(address, port), timeout);
280+
socket.setSoTimeout(0);
281+
return socket;
282+
} catch (IOException e) {
283+
closeQuietly(socket);
284+
throw e;
285+
}
286+
}
287+
288+
private static void closeQuietly(Socket socket) {
289+
if (socket != null) {
290+
try {
291+
socket.close();
292+
} catch (IOException ignored) {
293+
}
294+
}
295+
}
296+
}

src/main/java/com/trilead/ssh2/transport/TransportManager.java

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
import com.trilead.ssh2.ExtensionInfo;
55
import com.trilead.ssh2.packets.PacketExtInfo;
66
import java.io.IOException;
7-
import java.net.Inet6Address;
8-
import java.net.InetAddress;
9-
import java.net.InetSocketAddress;
107
import java.net.Socket;
118
import java.security.SecureRandom;
129
import java.util.Vector;
@@ -282,22 +279,6 @@ public void close(Throwable cause, boolean useDisconnectPacket)
282279
}
283280
}
284281

285-
private static InetAddress getIPv4Address(InetAddress[] addresses) {
286-
for (InetAddress address : addresses) {
287-
if (! (address instanceof Inet6Address)) {
288-
return address;
289-
}
290-
}
291-
return null;
292-
}
293-
private static Inet6Address getIPv6Address(InetAddress[] addresses) {
294-
for (InetAddress address : addresses) {
295-
if (address instanceof Inet6Address) {
296-
return (Inet6Address) address;
297-
}
298-
}
299-
return null;
300-
}
301282

302283
private void establishConnection(ProxyData proxyData, int connectTimeout, IpVersion ipVersion) throws IOException
303284
{
@@ -310,23 +291,7 @@ private void establishConnection(ProxyData proxyData, int connectTimeout, IpVers
310291
private static Socket connectDirect(String hostname, int port, int connectTimeout, IpVersion ipVersion)
311292
throws IOException
312293
{
313-
Socket sock = new Socket();
314-
InetAddress addr;
315-
if (ipVersion == IpVersion.IPV4_ONLY)
316-
{
317-
addr = getIPv4Address(InetAddress.getAllByName(hostname));
318-
}
319-
else if (ipVersion == IpVersion.IPV6_ONLY)
320-
{
321-
addr = getIPv6Address(InetAddress.getAllByName(hostname));
322-
}
323-
else // Assume (ipVersion == IpVersion.IPV4_AND_IPV6), the default.
324-
{
325-
addr = InetAddress.getByName(hostname);
326-
}
327-
sock.connect(new InetSocketAddress(addr, port), connectTimeout);
328-
sock.setSoTimeout(0);
329-
return sock;
294+
return new HappyEyeballsConnector().connect(hostname, port, connectTimeout, ipVersion);
330295
}
331296

332297
public void initialize(CryptoWishList cwl, ServerHostKeyVerifier verifier, DHGexParameters dhgex,

0 commit comments

Comments
 (0)