Skip to content
This repository has been archived by the owner on Dec 9, 2024. It is now read-only.

Commit

Permalink
Change DNS lookup method (#323) (#325)
Browse files Browse the repository at this point in the history
* Change DNS lookup method (#323)

* Fix checkstyle errors

* Improve exception handling.

* Unwrap and throw UnknownHostException from ExecutionException
Add testcases for handling UnknownHostException and TimeoutException

* Fix indentation
  • Loading branch information
npomaroli authored Jun 15, 2021
1 parent 1a736ce commit 968f29a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 78 deletions.
91 changes: 43 additions & 48 deletions src/main/java/com/hazelcast/kubernetes/DnsEndpointResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,66 +17,50 @@
package com.hazelcast.kubernetes;

import com.hazelcast.config.NetworkConfig;
import com.hazelcast.core.HazelcastException;
import com.hazelcast.logging.ILogger;
import com.hazelcast.cluster.Address;
import com.hazelcast.spi.discovery.DiscoveryNode;
import com.hazelcast.spi.discovery.SimpleDiscoveryNode;

import javax.naming.Context;
import javax.naming.NameNotFoundException;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import javax.naming.directory.InitialDirContext;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

final class DnsEndpointResolver
extends HazelcastKubernetesDiscoveryStrategy.EndpointResolver {
// executor service for dns lookup calls
private static final ExecutorService DNS_LOOKUP_SERVICE = Executors.newCachedThreadPool();

private final String serviceDns;
private final int port;
private final DirContext dirContext;
private final int serviceDnsTimeout;

DnsEndpointResolver(ILogger logger, String serviceDns, int port, DirContext dirContext) {
DnsEndpointResolver(ILogger logger, String serviceDns, int port, int serviceDnsTimeout) {
super(logger);
this.serviceDns = serviceDns;
this.port = port;
this.dirContext = dirContext;
}

DnsEndpointResolver(ILogger logger, String serviceDns, int port, int serviceDnsTimeout) {
this(logger, serviceDns, port, createDirContext(serviceDnsTimeout));

}

@SuppressWarnings("checkstyle:magicnumber")
private static DirContext createDirContext(int serviceDnsTimeout) {
Hashtable<String, String> env = new Hashtable<String, String>();
env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.dns.DnsContextFactory");
env.put(Context.PROVIDER_URL, "dns:");
env.put("com.sun.jndi.dns.timeout.initial", String.valueOf(serviceDnsTimeout * 1000L));
try {
return new InitialDirContext(env);
} catch (NamingException e) {
throw new HazelcastException("Error while initializing DirContext", e);
}
this.serviceDnsTimeout = serviceDnsTimeout;
}

List<DiscoveryNode> resolve() {
try {
return lookup();
} catch (NameNotFoundException e) {
logger.warning(String.format("DNS lookup for serviceDns '%s' failed: name not found", serviceDns));
} catch (TimeoutException e) {
logger.warning(String.format("DNS lookup for serviceDns '%s' failed: DNS resolution timeout", serviceDns));
return Collections.emptyList();
} catch (UnknownHostException e) {
logger.warning(String.format("DNS lookup for serviceDns '%s' failed: unknown host", serviceDns));
return Collections.emptyList();
} catch (Exception e) {
logger.warning(String.format("DNS lookup for serviceDns '%s' failed", serviceDns), e);
Expand All @@ -85,20 +69,32 @@ List<DiscoveryNode> resolve() {
}

private List<DiscoveryNode> lookup()
throws NamingException, UnknownHostException {
throws UnknownHostException, InterruptedException, ExecutionException, TimeoutException {
Set<String> addresses = new HashSet<String>();
Attributes attributes = dirContext.getAttributes(serviceDns, new String[]{"SRV"});
Attribute srvAttribute = attributes.get("srv");
if (srvAttribute != null) {
NamingEnumeration<?> servers = srvAttribute.getAll();
while (servers.hasMore()) {
String server = (String) servers.next();
String serverHost = extractHost(server);
InetAddress address = InetAddress.getByName(serverHost);

Future<InetAddress[]> future = DNS_LOOKUP_SERVICE.submit(new Callable<InetAddress[]>() {
@Override
public InetAddress[] call() throws Exception {
return getAllInetAddresses();
}
});

try {
for (InetAddress address : future.get(serviceDnsTimeout, TimeUnit.SECONDS)) {
if (addresses.add(address.getHostAddress()) && logger.isFinestEnabled()) {
logger.finest("Found node service with address: " + address);
}
}
} catch (ExecutionException e) {
if (e.getCause() instanceof UnknownHostException) {
throw (UnknownHostException) e.getCause();
} else {
throw e;
}
} catch (TimeoutException e) {
// cancel DNS lookup
future.cancel(true);
throw e;
}

if (addresses.size() == 0) {
Expand All @@ -114,13 +110,12 @@ private List<DiscoveryNode> lookup()
}

/**
* Extracts host from the DNS record.
* <p>
* Sample record: "10 25 0 6235386366386436.my-release-hazelcast.default.svc.cluster.local".
* Do the actual lookup
* @return array of resolved inet addresses
* @throws UnknownHostException
*/
private static String extractHost(String server) {
String host = server.split(" ")[3];
return host.replaceAll("\\\\.$", "");
private InetAddress[] getAllInetAddresses() throws UnknownHostException {
return InetAddress.getAllByName(serviceDns);
}

private static int getHazelcastPort(int port) {
Expand Down
76 changes: 46 additions & 30 deletions src/test/java/com/hazelcast/kubernetes/DnsEndpointResolverTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,24 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

import javax.naming.NameNotFoundException;
import javax.naming.NamingEnumeration;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@RunWith(PowerMockRunner.class)
Expand All @@ -47,45 +48,30 @@ public class DnsEndpointResolverTest {
private static final ILogger LOGGER = new NoLogFactory().getLogger("no");

private static final String SERVICE_DNS = "my-release-hazelcast.default.svc.cluster.local";
private static final int DEFAULT_SERVICE_DNS_TIMEOUT_SECONDS = 5;
private static final int TEST_DNS_TIMEOUT_SECONDS = 1;
private static final int UNSET_PORT = 0;
private static final int DEFAULT_PORT = 5701;
private static final int CUSTOM_PORT = 5702;
private static final String DNS_SERVER_1 = String.format("12345.%s", SERVICE_DNS);
private static final String DNS_SERVER_2 = String.format("6789.%s", SERVICE_DNS);
private static final String DNS_ENTRY_SERVER_1 = String.format("10 25 0 %s", DNS_SERVER_1);
private static final String DNS_ENTRY_SERVER_2 = String.format("10 25 0 %s", DNS_SERVER_2);
private static final String IP_SERVER_1 = "192.168.0.5";
private static final String IP_SERVER_2 = "192.168.0.6";

@Mock
private NamingEnumeration servers;
@Mock
private DirContext dirContext;

@Before
public void setUp()
throws Exception {
PowerMockito.mockStatic(InetAddress.class);

Attributes attributes = mock(Attributes.class);
when(dirContext.getAttributes(SERVICE_DNS, new String[]{"SRV"})).thenReturn(attributes);
Attribute attribute = mock(Attribute.class);
when(attributes.get("srv")).thenReturn(attribute);
when(attribute.getAll()).thenReturn(servers);
when(servers.next()).thenReturn(DNS_ENTRY_SERVER_1, DNS_ENTRY_SERVER_2);
when(servers.hasMore()).thenReturn(true, true, false);
InetAddress address1 = mock(InetAddress.class);
PowerMockito.when(InetAddress.getByName(DNS_SERVER_1)).thenReturn(address1);
InetAddress address2 = mock(InetAddress.class);
PowerMockito.when(InetAddress.getByName(DNS_SERVER_2)).thenReturn(address2);
when(address1.getHostAddress()).thenReturn(IP_SERVER_1);
when(address2.getHostAddress()).thenReturn(IP_SERVER_2);
PowerMockito.when(InetAddress.getAllByName(SERVICE_DNS)).thenReturn(new InetAddress[]{address1, address2});
}

@Test
public void resolve() {
// given
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, UNSET_PORT, dirContext);
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, UNSET_PORT, DEFAULT_SERVICE_DNS_TIMEOUT_SECONDS);

// when
List<DiscoveryNode> result = dnsEndpointResolver.resolve();
Expand All @@ -101,7 +87,7 @@ public void resolve() {
@Test
public void resolveCustomPort() {
// given
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, CUSTOM_PORT, dirContext);
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, CUSTOM_PORT, DEFAULT_SERVICE_DNS_TIMEOUT_SECONDS);

// when
List<DiscoveryNode> result = dnsEndpointResolver.resolve();
Expand All @@ -118,28 +104,58 @@ public void resolveCustomPort() {
public void resolveException()
throws Exception {
// given
when(dirContext.getAttributes(SERVICE_DNS, new String[]{"SRV"})).thenThrow(new NameNotFoundException());
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, UNSET_PORT, dirContext);
ILogger logger = mock(ILogger.class);
PowerMockito.when(InetAddress.getAllByName(SERVICE_DNS)).thenThrow(new UnknownHostException());
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(logger, SERVICE_DNS, UNSET_PORT, DEFAULT_SERVICE_DNS_TIMEOUT_SECONDS);

// when
List<DiscoveryNode> result = dnsEndpointResolver.resolve();

// then
assertEquals(0, result.size());
verify(logger).warning(String.format("DNS lookup for serviceDns '%s' failed: unknown host", SERVICE_DNS));
verify(logger, never()).warning(anyString(), any(Throwable.class));
}

@Test
public void resolveNotFound()
throws Exception {
// given
when(servers.hasMore()).thenReturn(false);
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, UNSET_PORT, dirContext);
PowerMockito.when(InetAddress.getAllByName(SERVICE_DNS)).thenReturn(new InetAddress[0]);
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(LOGGER, SERVICE_DNS, UNSET_PORT, DEFAULT_SERVICE_DNS_TIMEOUT_SECONDS);

// when
List<DiscoveryNode> result = dnsEndpointResolver.resolve();

// then
assertEquals(0, result.size());
}

@Test
public void resolveTimeout()
throws Exception {
// given
ILogger logger = mock(ILogger.class);
PowerMockito.when(InetAddress.getAllByName(SERVICE_DNS)).then(waitAndAnswer());
DnsEndpointResolver dnsEndpointResolver = new DnsEndpointResolver(logger, SERVICE_DNS, UNSET_PORT, TEST_DNS_TIMEOUT_SECONDS);

// when
List<DiscoveryNode> result = dnsEndpointResolver.resolve();

// then
assertEquals(0, result.size());
verify(logger).warning(String.format("DNS lookup for serviceDns '%s' failed: DNS resolution timeout", SERVICE_DNS));
verify(logger, never()).warning(anyString(), any(Throwable.class));
}

private static Answer<InetAddress[]> waitAndAnswer() {
return new Answer<InetAddress[]>() {
@Override
public InetAddress[] answer(InvocationOnMock invocation) throws Throwable {
Thread.sleep(TEST_DNS_TIMEOUT_SECONDS * 5 * 1000);
return new InetAddress[0];
}
};
}

private static Set<?> setOf(Object... objects) {
Expand Down

0 comments on commit 968f29a

Please sign in to comment.