diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index f8779651..d20cfe8c 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -100,8 +100,9 @@ import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; import javax.annotation.concurrent.GuardedBy; -import java.io.File; -import java.io.InterruptedIOException; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import java.io.*; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.management.MemoryUsage; @@ -109,6 +110,8 @@ import java.lang.reflect.Method; import java.net.URI; import java.nio.channels.ClosedByInterruptException; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; @@ -428,10 +431,38 @@ public abstract class ModelMesh extends ThriftService } } + private static final String SSL_TRUSTSTORE_PATH_PROPERTY = "watson.ssl.truststore.path"; + private static final String SSL_TRUSTSTORE_PASSWORD_PROPERTY = "watson.ssl.truststore.password"; + + private static SSLContext sslContext = null; + + private static SSLContext loadSSLContext() throws Exception { + if (sslContext == null) { + final String trustStorePath = System.getProperty(SSL_TRUSTSTORE_PATH_PROPERTY); + final String trustStorePassword = System.getProperty(SSL_TRUSTSTORE_PASSWORD_PROPERTY); + + if (trustStorePath == null || trustStorePassword == null) { + throw new IllegalArgumentException("Truststore settings not found in system properties"); + } + + final KeyStore trustStore = KeyStore.getInstance("JKS"); + try (FileInputStream trustStoreStream = new FileInputStream(trustStorePath)) { + trustStore.load(trustStoreStream, trustStorePassword.toCharArray()); + } + + final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(trustStore); + + sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagerFactory.getTrustManagers(), null); + } + return sslContext; + } + private PayloadProcessor initPayloadProcessor() { String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null); logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions); - if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) { + if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) { List payloadProcessors = new ArrayList<>(); for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) { try { @@ -441,7 +472,17 @@ private PayloadProcessor initPayloadProcessor() { String modelId = uri.getQuery(); String method = uri.getFragment(); if ("http".equals(processorName)) { + logger.info("Initializing HTTP payload processor"); processor = new RemotePayloadProcessor(uri); + } else if ("https".equals(processorName)) { + SSLContext sslContext; + try { + sslContext = loadSSLContext(); + } catch (Exception missingAlgorithmException) { + throw new UncheckedIOException(new IOException(missingAlgorithmException)); + } + logger.info("Initializing HTTPS payload processor"); + processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters()); } else if ("logger".equals(processorName)) { processor = new LoggingPayloadProcessor(); } diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java index 23c2fba1..12a64f1f 100644 --- a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java +++ b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java @@ -23,6 +23,8 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import com.fasterxml.jackson.databind.ObjectMapper; import io.grpc.Metadata; @@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor { private final URI uri; + private final SSLContext sslContext; + private final SSLParameters sslParameters; + private final HttpClient client; public RemotePayloadProcessor(URI uri) { + this(uri, null, null); + } + + public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) { this.uri = uri; - this.client = HttpClient.newHttpClient(); + this.sslContext = sslContext; + this.sslParameters = sslParameters; + if (sslContext != null && sslParameters != null) { + this.client = HttpClient.newBuilder() + .sslContext(sslContext) + .sslParameters(sslParameters) + .build(); + } else { + this.client = HttpClient.newHttpClient(); + } } @Override diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java index ec08ea0a..a8da3c4c 100644 --- a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java @@ -16,7 +16,9 @@ package com.ibm.watson.modelmesh.payload; +import java.io.IOException; import java.net.URI; +import java.security.NoSuchAlgorithmException; import io.grpc.Metadata; import io.grpc.Status; @@ -24,22 +26,45 @@ import io.netty.buffer.Unpooled; import org.junit.jupiter.api.Test; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + import static org.junit.jupiter.api.Assertions.assertFalse; class RemotePayloadProcessorTest { + void testDestinationUnreachable() throws IOException { + URI uri = URI.create("http://this-does-not-exist:123"); + try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) { + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } + } + @Test - void testDestinationUnreachable() { - RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123")); - String id = "123"; - String modelId = "456"; - String method = "predict"; - Status kind = Status.INVALID_ARGUMENT; - Metadata metadata = new Metadata(); - metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); - ByteBuf data = Unpooled.buffer(4); - Payload payload = new Payload(id, modelId, method, metadata, data, kind); - assertFalse(remotePayloadProcessor.process(payload)); + void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException { + URI uri = URI.create("https://this-does-not-exist:123"); + SSLContext sslContext = SSLContext.getDefault(); + SSLParameters sslParameters = sslContext.getDefaultSSLParameters(); + try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) { + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } } } \ No newline at end of file