Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RHOAIENG-4963: ModelMesh should support TLS in payload processors #65

Merged
merged 2 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,18 @@
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;

import javax.annotation.concurrent.GuardedBy;
import java.io.File;
import java.io.InterruptedIOException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import java.io.*;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.channels.ClosedByInterruptException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
Expand Down Expand Up @@ -428,10 +431,38 @@ public abstract class ModelMesh extends ThriftService
}
}

private static final String SSL_TRUSTSTORE_PATH_PROPERTY = "watson.ssl.truststore.path";
private static final String SSL_TRUSTSTORE_PASSWORD_PROPERTY = "watson.ssl.truststore.password";

private static SSLContext sslContext = null;

private static SSLContext loadSSLContext() throws Exception {
if (sslContext == null) {
final String trustStorePath = System.getProperty(SSL_TRUSTSTORE_PATH_PROPERTY);
final String trustStorePassword = System.getProperty(SSL_TRUSTSTORE_PASSWORD_PROPERTY);

if (trustStorePath == null || trustStorePassword == null) {
throw new IllegalArgumentException("Truststore settings not found in system properties");
}

final KeyStore trustStore = KeyStore.getInstance("JKS");
try (FileInputStream trustStoreStream = new FileInputStream(trustStorePath)) {
trustStore.load(trustStoreStream, trustStorePassword.toCharArray());
}

final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(trustStore);

sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustManagerFactory.getTrustManagers(), null);
}
return sslContext;
}

private PayloadProcessor initPayloadProcessor() {
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) {
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
try {
Expand All @@ -441,7 +472,17 @@ private PayloadProcessor initPayloadProcessor() {
String modelId = uri.getQuery();
String method = uri.getFragment();
if ("http".equals(processorName)) {
logger.info("Initializing HTTP payload processor");
processor = new RemotePayloadProcessor(uri);
} else if ("https".equals(processorName)) {
SSLContext sslContext;
try {
sslContext = loadSSLContext();
} catch (Exception missingAlgorithmException) {
throw new UncheckedIOException(new IOException(missingAlgorithmException));
}
logger.info("Initializing HTTPS payload processor");
processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters());
} else if ("logger".equals(processorName)) {
processor = new LoggingPayloadProcessor();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.grpc.Metadata;
Expand All @@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor {

private final URI uri;

private final SSLContext sslContext;
private final SSLParameters sslParameters;

private final HttpClient client;

public RemotePayloadProcessor(URI uri) {
this(uri, null, null);
}

public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) {
this.uri = uri;
this.client = HttpClient.newHttpClient();
this.sslContext = sslContext;
this.sslParameters = sslParameters;
if (sslContext != null && sslParameters != null) {
this.client = HttpClient.newBuilder()
.sslContext(sslContext)
.sslParameters(sslParameters)
.build();
} else {
this.client = HttpClient.newHttpClient();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,55 @@

package com.ibm.watson.modelmesh.payload;

import java.io.IOException;
import java.net.URI;
import java.security.NoSuchAlgorithmException;

import io.grpc.Metadata;
import io.grpc.Status;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import static org.junit.jupiter.api.Assertions.assertFalse;

class RemotePayloadProcessorTest {

void testDestinationUnreachable() throws IOException {
URI uri = URI.create("http://this-does-not-exist:123");
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}

@Test
void testDestinationUnreachable() {
RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123"));
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException {
URI uri = URI.create("https://this-does-not-exist:123");
SSLContext sslContext = SSLContext.getDefault();
SSLParameters sslParameters = sslContext.getDefaultSSLParameters();
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}
}
Loading