From a645c374b3a9bea66851d421baf030810a8fc7e4 Mon Sep 17 00:00:00 2001 From: Rui Vieira Date: Sun, 21 Jul 2024 23:01:36 +0100 Subject: [PATCH] feat: Add SSL context loading Signed-off-by: Rui Vieira (cherry picked from commit ed8161ab449dd782f5cfb892eb7584edadba6576) --- .../com/ibm/watson/modelmesh/ModelMesh.java | 41 ++++++++++++++++--- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index aaabaa53..d20cfe8c 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -101,10 +101,8 @@ import javax.annotation.concurrent.GuardedBy; import javax.net.ssl.SSLContext; -import java.io.File; -import java.io.IOException; -import java.io.InterruptedIOException; -import java.io.UncheckedIOException; +import javax.net.ssl.TrustManagerFactory; +import java.io.*; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.management.MemoryUsage; @@ -112,6 +110,7 @@ import java.lang.reflect.Method; import java.net.URI; import java.nio.channels.ClosedByInterruptException; +import java.security.KeyStore; import java.security.NoSuchAlgorithmException; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -432,6 +431,34 @@ public abstract class ModelMesh extends ThriftService } } + private static final String SSL_TRUSTSTORE_PATH_PROPERTY = "watson.ssl.truststore.path"; + private static final String SSL_TRUSTSTORE_PASSWORD_PROPERTY = "watson.ssl.truststore.password"; + + private static SSLContext sslContext = null; + + private static SSLContext loadSSLContext() throws Exception { + if (sslContext == null) { + final String trustStorePath = System.getProperty(SSL_TRUSTSTORE_PATH_PROPERTY); + final String trustStorePassword = System.getProperty(SSL_TRUSTSTORE_PASSWORD_PROPERTY); + + if (trustStorePath == null || trustStorePassword == null) { + throw new IllegalArgumentException("Truststore settings not found in system properties"); + } + + final KeyStore trustStore = KeyStore.getInstance("JKS"); + try (FileInputStream trustStoreStream = new FileInputStream(trustStorePath)) { + trustStore.load(trustStoreStream, trustStorePassword.toCharArray()); + } + + final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(trustStore); + + sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagerFactory.getTrustManagers(), null); + } + return sslContext; + } + private PayloadProcessor initPayloadProcessor() { String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null); logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions); @@ -445,14 +472,16 @@ private PayloadProcessor initPayloadProcessor() { String modelId = uri.getQuery(); String method = uri.getFragment(); if ("http".equals(processorName)) { + logger.info("Initializing HTTP payload processor"); processor = new RemotePayloadProcessor(uri); } else if ("https".equals(processorName)) { SSLContext sslContext; try { - sslContext = SSLContext.getDefault(); - } catch (NoSuchAlgorithmException missingAlgorithmException) { + sslContext = loadSSLContext(); + } catch (Exception missingAlgorithmException) { throw new UncheckedIOException(new IOException(missingAlgorithmException)); } + logger.info("Initializing HTTPS payload processor"); processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters()); } else if ("logger".equals(processorName)) { processor = new LoggingPayloadProcessor();