Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SSL context/params in RPP for HTTPS #146

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,19 @@
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;

import javax.annotation.concurrent.GuardedBy;
import javax.net.ssl.SSLContext;
import java.io.File;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.UncheckedIOException;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.channels.ClosedByInterruptException;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
Expand Down Expand Up @@ -431,7 +435,7 @@ public abstract class ModelMesh extends ThriftService
private PayloadProcessor initPayloadProcessor() {
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) {
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
try {
Expand All @@ -442,6 +446,14 @@ private PayloadProcessor initPayloadProcessor() {
String method = uri.getFragment();
if ("http".equals(processorName)) {
processor = new RemotePayloadProcessor(uri);
} else if ("https".equals(processorName)) {
SSLContext sslContext;
try {
sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException missingAlgorithmException) {
throw new UncheckedIOException(new IOException(missingAlgorithmException));
}
processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters());
} else if ("logger".equals(processorName)) {
processor = new LoggingPayloadProcessor();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.grpc.Metadata;
Expand All @@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor {

private final URI uri;

private final SSLContext sslContext;
private final SSLParameters sslParameters;

private final HttpClient client;

public RemotePayloadProcessor(URI uri) {
this(uri, null, null);
}

public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) {
this.uri = uri;
this.client = HttpClient.newHttpClient();
this.sslContext = sslContext;
this.sslParameters = sslParameters;
if (sslContext != null && sslParameters != null) {
this.client = HttpClient.newBuilder()
.sslContext(sslContext)
.sslParameters(sslParameters)
.build();
} else {
this.client = HttpClient.newHttpClient();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,55 @@

package com.ibm.watson.modelmesh.payload;

import java.io.IOException;
import java.net.URI;
import java.security.NoSuchAlgorithmException;

import io.grpc.Metadata;
import io.grpc.Status;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import static org.junit.jupiter.api.Assertions.assertFalse;

class RemotePayloadProcessorTest {

void testDestinationUnreachable() throws IOException {
URI uri = URI.create("http://this-does-not-exist:123");
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}

@Test
void testDestinationUnreachable() {
RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123"));
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException {
URI uri = URI.create("https://this-does-not-exist:123");
SSLContext sslContext = SSLContext.getDefault();
SSLParameters sslParameters = sslContext.getDefaultSSLParameters();
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}
}
Loading