diff --git a/build.gradle b/build.gradle index b6e02f175de..771ccfefb07 100644 --- a/build.gradle +++ b/build.gradle @@ -10,7 +10,7 @@ plugins { id 'maven-publish' id 'com.diffplug.spotless' version '6.12.0' id 'com.dorongold.task-tree' version '2.1.0' - id 'com.github.johnrengelman.shadow' version '6.1.0' apply false + id 'com.github.johnrengelman.shadow' version '7.1.2' apply false id 'com.github.spotbugs' version '4.8.0' apply false id 'org.gradle.test-retry' version '1.5.0' apply false id 'com.form.diff-coverage' version '0.9.5' apply false @@ -87,6 +87,7 @@ ext.libraries = [ grpcProtobuf: "io.grpc:grpc-protobuf:${grpcVersion}", grpcServices: "io.grpc:grpc-services:${grpcVersion}", grpcStub: "io.grpc:grpc-stub:${grpcVersion}", + grpcTesting: "io.grpc:grpc-testing:${grpcVersion}", hadoopCommon: "org.apache.hadoop:hadoop-common:${hadoopVersion}", hadoopHdfs: "org.apache.hadoop:hadoop-hdfs:${hadoopVersion}", httpAsyncClient: 'org.apache.httpcomponents:httpasyncclient:4.1.5', @@ -515,6 +516,12 @@ subprojects { value = 'COVEREDRATIO' minimum = threshold } + // Ignore generate files + afterEvaluate { + classDirectories.setFrom(files(classDirectories.files.collect { fileTree(dir: it, exclude: [ + '**/com/linkedin/venice/protocols/**', + ])})) + } } } } @@ -818,4 +825,4 @@ task verifyJdkVersion { gradle.taskGraph.whenReady { // Ensure the JDK version is verified before any other tasks verifyJdkVersion -} +} \ No newline at end of file diff --git a/clients/venice-admin-tool/build.gradle b/clients/venice-admin-tool/build.gradle index a6f46156a6e..431c8b54d08 100644 --- a/clients/venice-admin-tool/build.gradle +++ b/clients/venice-admin-tool/build.gradle @@ -44,6 +44,11 @@ jar { } } +shadowJar { + // Enable merging service files from different dependencies. Required to make gRPC based clients work. + mergeServiceFiles() +} + ext { jacocoCoverageThreshold = 0.00 } diff --git a/clients/venice-client/src/main/java/com/linkedin/venice/fastclient/transport/GrpcTransportClient.java b/clients/venice-client/src/main/java/com/linkedin/venice/fastclient/transport/GrpcTransportClient.java index 30029701f24..145fd370dae 100644 --- a/clients/venice-client/src/main/java/com/linkedin/venice/fastclient/transport/GrpcTransportClient.java +++ b/clients/venice-client/src/main/java/com/linkedin/venice/fastclient/transport/GrpcTransportClient.java @@ -21,10 +21,8 @@ import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap; import io.grpc.ChannelCredentials; import io.grpc.Grpc; -import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.Status; -import io.grpc.TlsChannelCredentials; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.Arrays; @@ -74,7 +72,7 @@ public GrpcTransportClient(GrpcClientConfig grpcClientConfig) { this.port = port; this.serverGrpcChannels = new VeniceConcurrentHashMap<>(); this.stubCache = new VeniceConcurrentHashMap<>(); - this.channelCredentials = buildChannelCredentials(sslFactory); + this.channelCredentials = GrpcUtils.buildChannelCredentials(sslFactory); } @Override @@ -99,25 +97,6 @@ public void close() throws IOException { r2TransportClientForNonStorageOps.close(); } - @VisibleForTesting - ChannelCredentials buildChannelCredentials(SSLFactory sslFactory) { - // TODO: Evaluate if this needs to fail instead since it depends on plain text support on server - if (sslFactory == null) { - return InsecureChannelCredentials.create(); - } - - try { - TlsChannelCredentials.Builder tlsBuilder = TlsChannelCredentials.newBuilder() - .keyManager(GrpcUtils.getKeyManagers(sslFactory)) - .trustManager(GrpcUtils.getTrustManagers(sslFactory)); - return tlsBuilder.build(); - } catch (Exception e) { - throw new VeniceClientException( - "Failed to initialize SSL channel credentials for Venice gRPC Transport Client", - e); - } - } - @VisibleForTesting VeniceClientRequest buildVeniceClientRequest(String[] requestParts, byte[] requestBody, boolean isSingleGet) { VeniceClientRequest.Builder requestBuilder = VeniceClientRequest.newBuilder() diff --git a/clients/venice-client/src/test/java/com/linkedin/venice/fastclient/transport/GrpcTransportClientTest.java b/clients/venice-client/src/test/java/com/linkedin/venice/fastclient/transport/GrpcTransportClientTest.java index 02b228ce1b3..98577fad8e1 100644 --- a/clients/venice-client/src/test/java/com/linkedin/venice/fastclient/transport/GrpcTransportClientTest.java +++ b/clients/venice-client/src/test/java/com/linkedin/venice/fastclient/transport/GrpcTransportClientTest.java @@ -1,20 +1,27 @@ package com.linkedin.venice.fastclient.transport; -import static org.mockito.Mockito.*; -import static org.testng.Assert.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; import com.google.common.collect.ImmutableMap; import com.linkedin.r2.transport.common.Client; import com.linkedin.venice.HttpMethod; -import com.linkedin.venice.client.exceptions.VeniceClientException; import com.linkedin.venice.client.store.transport.TransportClient; import com.linkedin.venice.client.store.transport.TransportClientResponse; import com.linkedin.venice.fastclient.GrpcClientConfig; import com.linkedin.venice.protocols.VeniceClientRequest; import com.linkedin.venice.protocols.VeniceReadServiceGrpc; import com.linkedin.venice.protocols.VeniceServerResponse; -import com.linkedin.venice.security.SSLFactory; -import io.grpc.ChannelCredentials; import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -61,14 +68,6 @@ public void setUp() { grpcTransportClient = new GrpcTransportClient(mockClientConfig); } - @Test(expectedExceptions = VeniceClientException.class) - public void testBuildChannelCredentials() { - ChannelCredentials actualChannelCredentials = grpcTransportClient.buildChannelCredentials(null); - assertNotNull(actualChannelCredentials, "Null ssl factory should default to insecure channel credentials"); - - grpcTransportClient.buildChannelCredentials(mock(SSLFactory.class)); - } - @Test public void testBuildVeniceClientRequestForSingleGet() { VeniceClientRequest clientRequest = diff --git a/gradle/spotbugs/exclude.xml b/gradle/spotbugs/exclude.xml index 3d5a8dbf544..62537f3e67c 100644 --- a/gradle/spotbugs/exclude.xml +++ b/gradle/spotbugs/exclude.xml @@ -481,8 +481,10 @@ - - + + + + diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/ConfigKeys.java b/internal/venice-common/src/main/java/com/linkedin/venice/ConfigKeys.java index 52436d59c6d..639f0b03b78 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/ConfigKeys.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/ConfigKeys.java @@ -225,6 +225,27 @@ private ConfigKeys() { // What tags to assign to a controller instance public static final String CONTROLLER_INSTANCE_TAG_LIST = "controller.instance.tag.list"; + /** + * Whether to enable gRPC server in controller or not. + */ + public static final String CONTROLLER_GRPC_SERVER_ENABLED = "controller.grpc.server.enabled"; + + /** + * A port for the controller to listen on for incoming requests. On this port, the controller will + * server non-ssl requests. + */ + public static final String CONTROLLER_ADMIN_GRPC_PORT = "controller.admin.grpc.port"; + /** + * A port for the controller to listen on for incoming requests. On this port, the controller will + * only serve ssl requests. + */ + public static final String CONTROLLER_ADMIN_SECURE_GRPC_PORT = "controller.admin.secure.grpc.port"; + + /** + * Number of threads to use for the gRPC server in controller. + */ + public static final String CONTROLLER_GRPC_SERVER_THREAD_COUNT = "controller.grpc.server.thread.count"; + /** List of forbidden admin paths */ public static final String CONTROLLER_DISABLED_ROUTES = "controller.cluster.disabled.routes"; diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/ControllerApiConstants.java b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/ControllerApiConstants.java index a755a73d9ed..5289dd7aa0d 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/ControllerApiConstants.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/ControllerApiConstants.java @@ -7,7 +7,11 @@ public class ControllerApiConstants { public static final String SOURCE_GRID_FABRIC = "source_grid_fabric"; public static final String BATCH_JOB_HEARTBEAT_ENABLED = "batch_job_heartbeat_enabled"; - public static final String NAME = "store_name"; + public static final String STORE_NAME = "store_name"; + /** + * @deprecated Use {@link #STORE_NAME} instead. + */ + public static final String NAME = STORE_NAME; public static final String STORE_PARTITION = "store_partition"; public static final String STORE_VERSION = "store_version"; public static final String OWNER = "owner"; diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/LeaderControllerResponse.java b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/LeaderControllerResponse.java index 7332f1581b3..443025578eb 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/LeaderControllerResponse.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/LeaderControllerResponse.java @@ -5,6 +5,8 @@ public class LeaderControllerResponse private String cluster; private String url; private String secureUrl = null; + private String grpcUrl = null; + private String secureGrpcUrl = null; public String getCluster() { return cluster; @@ -29,4 +31,20 @@ public String getSecureUrl() { public void setSecureUrl(String url) { this.secureUrl = url; } + + public void setGrpcUrl(String url) { + this.grpcUrl = url; + } + + public String getGrpcUrl() { + return grpcUrl; + } + + public void setSecureGrpcUrl(String url) { + this.secureGrpcUrl = url; + } + + public String getSecureGrpcUrl() { + return secureGrpcUrl; + } } diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/ClusterDiscoveryRequest.java b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/ClusterDiscoveryRequest.java new file mode 100644 index 00000000000..cbe45126369 --- /dev/null +++ b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/ClusterDiscoveryRequest.java @@ -0,0 +1,9 @@ +package com.linkedin.venice.controllerapi.request; + +public class ClusterDiscoveryRequest extends ControllerRequest { + private static final String CLUSTER_NAME_PLACEHOLDER = "UNKNOWN"; + + public ClusterDiscoveryRequest(String storeName) { + super(CLUSTER_NAME_PLACEHOLDER, storeName); + } +} diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/ControllerRequest.java b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/ControllerRequest.java new file mode 100644 index 00000000000..248213536b3 --- /dev/null +++ b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/ControllerRequest.java @@ -0,0 +1,42 @@ +package com.linkedin.venice.controllerapi.request; + +import static com.linkedin.venice.controllerapi.ControllerApiConstants.CLUSTER; +import static com.linkedin.venice.controllerapi.ControllerApiConstants.STORE_NAME; + + +/** + * Base class for request objects used in controller endpoints. + * + * Extend this class to ensure required parameters are validated in the constructor of the extending class. + * This class is intended for use on both the client and server sides. + * All required parameters should be passed to and validated within the constructor of the extending class. + */ +public class ControllerRequest { + protected String clusterName; + protected String storeName; + + public ControllerRequest(String clusterName) { + this.clusterName = validateParam(clusterName, CLUSTER); + this.storeName = null; + } + + public ControllerRequest(String clusterName, String storeName) { + this.clusterName = validateParam(clusterName, CLUSTER); + this.storeName = validateParam(storeName, STORE_NAME); + } + + public String getClusterName() { + return clusterName; + } + + public String getStoreName() { + return storeName; + } + + public static String validateParam(String param, String paramName) { + if (param == null || param.isEmpty()) { + throw new IllegalArgumentException("The request is missing the " + paramName + ", which is a mandatory field."); + } + return param; + } +} diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/CreateNewStoreRequest.java b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/CreateNewStoreRequest.java new file mode 100644 index 00000000000..f96fc2968f0 --- /dev/null +++ b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/request/CreateNewStoreRequest.java @@ -0,0 +1,54 @@ +package com.linkedin.venice.controllerapi.request; + +/** + * Represents a request to create a new store in the specified Venice cluster with the provided parameters. + * This class encapsulates all necessary details for the creation of a store, including its name, owner, + * schema definitions, and access permissions. + */ +public class CreateNewStoreRequest extends ControllerRequest { + public static final String DEFAULT_STORE_OWNER = ""; + + private final String owner; + private final String keySchema; + private final String valueSchema; + private final boolean isSystemStore; + + // a JSON string representing the access permissions for the store + private final String accessPermissions; + + public CreateNewStoreRequest( + String clusterName, + String storeName, + String owner, + String keySchema, + String valueSchema, + String accessPermissions, + boolean isSystemStore) { + super(clusterName, storeName); + this.keySchema = validateParam(keySchema, "Key schema"); + this.valueSchema = validateParam(valueSchema, "Value schema"); + this.owner = owner == null ? DEFAULT_STORE_OWNER : owner; + this.accessPermissions = accessPermissions; + this.isSystemStore = isSystemStore; + } + + public String getOwner() { + return owner; + } + + public String getKeySchema() { + return keySchema; + } + + public String getValueSchema() { + return valueSchema; + } + + public String getAccessPermissions() { + return accessPermissions; + } + + public boolean isSystemStore() { + return isSystemStore; + } +} diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/transport/GrpcRequestResponseConverter.java b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/transport/GrpcRequestResponseConverter.java new file mode 100644 index 00000000000..70bc260888d --- /dev/null +++ b/internal/venice-common/src/main/java/com/linkedin/venice/controllerapi/transport/GrpcRequestResponseConverter.java @@ -0,0 +1,80 @@ +package com.linkedin.venice.controllerapi.transport; + +import com.google.protobuf.Any; +import com.linkedin.venice.client.exceptions.VeniceClientException; +import com.linkedin.venice.controllerapi.ControllerResponse; +import com.linkedin.venice.controllerapi.request.ControllerRequest; +import com.linkedin.venice.protocols.controller.ClusterStoreGrpcInfo; +import com.linkedin.venice.protocols.controller.ControllerGrpcErrorType; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcErrorInfo; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import io.grpc.stub.StreamObserver; + + +public class GrpcRequestResponseConverter { + public static ClusterStoreGrpcInfo getClusterStoreGrpcInfo(ControllerResponse response) { + ClusterStoreGrpcInfo.Builder builder = ClusterStoreGrpcInfo.newBuilder(); + if (response.getCluster() != null) { + builder.setClusterName(response.getCluster()); + } + if (response.getName() != null) { + builder.setStoreName(response.getName()); + } + return builder.build(); + } + + public static ClusterStoreGrpcInfo getClusterStoreGrpcInfo(ControllerRequest request) { + ClusterStoreGrpcInfo.Builder builder = ClusterStoreGrpcInfo.newBuilder(); + if (request.getClusterName() != null) { + builder.setClusterName(request.getClusterName()); + } + if (request.getStoreName() != null) { + builder.setStoreName(request.getStoreName()); + } + return builder.build(); + } + + public static void sendErrorResponse( + Code code, + ControllerGrpcErrorType errorType, + Exception e, + String clusterName, + String storeName, + StreamObserver responseObserver) { + VeniceControllerGrpcErrorInfo.Builder errorInfoBuilder = + VeniceControllerGrpcErrorInfo.newBuilder().setStatusCode(code.value()).setErrorType(errorType); + if (e.getMessage() != null) { + errorInfoBuilder.setErrorMessage(e.getMessage()); + } + if (clusterName != null) { + errorInfoBuilder.setClusterName(clusterName); + } + if (storeName != null) { + errorInfoBuilder.setStoreName(storeName); + } + // Wrap the error info into a com.google.rpc.Status message + com.google.rpc.Status status = + com.google.rpc.Status.newBuilder().setCode(code.value()).addDetails(Any.pack(errorInfoBuilder.build())).build(); + + // Send the error response + responseObserver.onError(StatusProto.toStatusRuntimeException(status)); + } + + public static VeniceControllerGrpcErrorInfo parseControllerGrpcError(StatusRuntimeException e) { + com.google.rpc.Status status = StatusProto.fromThrowable(e); + if (status != null) { + for (com.google.protobuf.Any detail: status.getDetailsList()) { + if (detail.is(VeniceControllerGrpcErrorInfo.class)) { + try { + return detail.unpack(VeniceControllerGrpcErrorInfo.class); + } catch (Exception unpackException) { + throw new VeniceClientException("Failed to unpack error details: " + unpackException.getMessage()); + } + } + } + } + throw new VeniceClientException("An unknown gRPC error occurred. Error code: " + Code.UNKNOWN.name()); + } +} diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/grpc/GrpcUtils.java b/internal/venice-common/src/main/java/com/linkedin/venice/grpc/GrpcUtils.java index f2fcb1c0e7a..818d7056713 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/grpc/GrpcUtils.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/grpc/GrpcUtils.java @@ -1,13 +1,17 @@ package com.linkedin.venice.grpc; import com.linkedin.venice.acl.handler.AccessResult; +import com.linkedin.venice.client.exceptions.VeniceClientException; import com.linkedin.venice.exceptions.VeniceException; import com.linkedin.venice.security.SSLConfig; import com.linkedin.venice.security.SSLFactory; import com.linkedin.venice.utils.SslUtils; +import io.grpc.ChannelCredentials; import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ServerCall; import io.grpc.Status; +import io.grpc.TlsChannelCredentials; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; @@ -97,4 +101,22 @@ private static KeyStore loadStore(String path, char[] password, String type) } return keyStore; } + + public static ChannelCredentials buildChannelCredentials(SSLFactory sslFactory) { + // TODO: Evaluate if this needs to fail instead since it depends on plain text support on server + if (sslFactory == null) { + return InsecureChannelCredentials.create(); + } + + try { + TlsChannelCredentials.Builder tlsBuilder = TlsChannelCredentials.newBuilder() + .keyManager(GrpcUtils.getKeyManagers(sslFactory)) + .trustManager(GrpcUtils.getTrustManagers(sslFactory)); + return tlsBuilder.build(); + } catch (Exception e) { + throw new VeniceClientException( + "Failed to initialize SSL channel credentials for Venice gRPC Transport Client", + e); + } + } } diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServer.java b/internal/venice-common/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServer.java similarity index 77% rename from services/venice-server/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServer.java rename to internal/venice-common/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServer.java index 3e6fc3546a1..c8c01193651 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServer.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServer.java @@ -30,14 +30,11 @@ public class VeniceGrpcServer { private final VeniceGrpcServerConfig config; public VeniceGrpcServer(VeniceGrpcServerConfig config) { - port = config.getPort(); - sslFactory = config.getSslFactory(); - executor = config.getExecutor(); - + this.port = config.getPort(); + this.sslFactory = config.getSslFactory(); + this.executor = config.getExecutor(); this.config = config; - initServerCredentials(); - server = Grpc.newServerBuilderForPort(config.getPort(), credentials) .executor(executor) // TODO: experiment with different executors for best performance .addService(ServerInterceptors.intercept(config.getService(), config.getInterceptors())) @@ -47,13 +44,13 @@ public VeniceGrpcServer(VeniceGrpcServerConfig config) { private void initServerCredentials() { if (sslFactory == null && config.getCredentials() == null) { - LOGGER.info("Creating gRPC server with insecure credentials"); + LOGGER.info("Creating gRPC server with insecure credentials on port: {}", port); credentials = InsecureServerCredentials.create(); return; } if (config.getCredentials() != null) { - LOGGER.info("Creating gRPC server with custom credentials"); + LOGGER.debug("Creating gRPC server with custom credentials"); credentials = config.getCredentials(); return; } @@ -74,9 +71,14 @@ private void initServerCredentials() { public void start() throws VeniceException { try { server.start(); + LOGGER.info( + "Started gRPC server for service: {} on port: {} isSecure: {}", + config.getService().getClass().getSimpleName(), + port, + isSecure()); } catch (IOException exception) { LOGGER.error( - "Failed to start gRPC Server for service {} on port {}", + "Failed to start gRPC server for service: {} on port: {}", config.getService().getClass().getSimpleName(), port, exception); @@ -84,17 +86,30 @@ public void start() throws VeniceException { } } - public boolean isShutdown() { - return server.isShutdown(); + public boolean isRunning() { + return !server.isShutdown(); } public boolean isTerminated() { return server.isTerminated(); } + private boolean isSecure() { + return !(credentials instanceof InsecureServerCredentials); + } + public void stop() { + LOGGER.info( + "Shutting down gRPC server for service: {} on port: {} isSecure: {}", + config.getService().getClass().getSimpleName(), + port, + isSecure()); if (server != null && !server.isShutdown()) { server.shutdown(); } } + + public Server getServer() { + return server; + } } diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServerConfig.java b/internal/venice-common/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServerConfig.java similarity index 89% rename from services/venice-server/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServerConfig.java rename to internal/venice-common/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServerConfig.java index c0a2f5b9841..96b271f1f27 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServerConfig.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/grpc/VeniceGrpcServerConfig.java @@ -112,18 +112,18 @@ public VeniceGrpcServerConfig build() { private void verifyAndAddDefaults() { if (port == null) { - throw new IllegalArgumentException("Port must be set"); + throw new IllegalArgumentException("Port value is required to create the gRPC server but was not provided."); } if (service == null) { - throw new IllegalArgumentException("Service must be set"); + throw new IllegalArgumentException("A non-null gRPC service instance is required to create the server."); + } + if (numThreads <= 0 && executor == null) { + throw new IllegalArgumentException( + "gRPC server creation requires a valid number of threads (numThreads > 0) or a non-null executor."); } if (interceptors == null) { interceptors = Collections.emptyList(); } - if (numThreads <= 0 && executor == null) { - throw new IllegalArgumentException("Either numThreads or executor must be set"); - } - if (executor == null) { executor = Executors.newFixedThreadPool(numThreads); } diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/meta/Instance.java b/internal/venice-common/src/main/java/com/linkedin/venice/meta/Instance.java index 7b0aa5c7a3f..29a848d5029 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/meta/Instance.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/meta/Instance.java @@ -34,23 +34,38 @@ public class Instance { private final String url; private final String sUrl; + private final int grpcPort; + private final int grpcSslPort; + private final String grpcUrl; + private final String grpcSslUrl; + // TODO: generate nodeId from host and port, should be "host_port", or generate host and port from id. public Instance(String nodeId, String host, int port) { - this(nodeId, host, port, port); + this(nodeId, host, port, port, -1, -1); + } + + public Instance(String nodeId, String host, int port, int grpcPort, int grpcSslPort) { + this(nodeId, host, port, port, grpcPort, grpcSslPort); } public Instance( @JsonProperty("nodeId") String nodeId, @JsonProperty("host") String host, @JsonProperty("port") int port, - @JsonProperty("sslPort") int sslPort) { + @JsonProperty("sslPort") int sslPort, + @JsonProperty("grpcPort") int grpcPort, + @JsonProperty("grpcSslPort") int grpcSslPort) { this.nodeId = nodeId; this.host = host; - validatePort("port", port); - this.port = port; + this.port = validatePort("port", port); this.sslPort = sslPort; this.url = "http://" + host + ":" + port + "/"; this.sUrl = "https://" + host + ":" + sslPort + "/"; + + this.grpcPort = grpcPort; + this.grpcSslPort = grpcSslPort; + this.grpcUrl = host + ":" + grpcPort; + this.grpcSslUrl = host + ":" + grpcSslPort; } public static Instance fromHostAndPort(String hostName, int port) { @@ -91,6 +106,22 @@ public int getSslPort() { return sslPort; } + public int getGrpcPort() { + return grpcPort; + } + + public int getGrpcSslPort() { + return grpcSslPort; + } + + public String getGrpcUrl() { + return grpcUrl; + } + + public String getGrpcSslUrl() { + return grpcSslUrl; + } + /*** * Convenience method for getting a host and port based url. * Wraps IPv6 host strings in square brackets @@ -111,10 +142,11 @@ public String getUrl() { return getUrl(false); } - private void validatePort(String name, int port) { + private int validatePort(String name, int port) { if (port < 0 || port > 65535) { throw new IllegalArgumentException("Invalid " + name + ": " + port); } + return port; } // Autogen except for .toLowerCase() diff --git a/internal/venice-common/src/main/proto/VeniceControllerGrpcService.proto b/internal/venice-common/src/main/proto/VeniceControllerGrpcService.proto new file mode 100644 index 00000000000..dbd2146e0b7 --- /dev/null +++ b/internal/venice-common/src/main/proto/VeniceControllerGrpcService.proto @@ -0,0 +1,224 @@ +syntax = 'proto3'; +package com.linkedin.venice.protocols.controller; + +import "google/rpc/status.proto"; +import "google/rpc/error_details.proto"; +import "VeniceStore.proto"; + +option java_multiple_files = true; + + +service VeniceControllerGrpcService { + // ClusterDiscovery + rpc discoverClusterForStore(DiscoverClusterGrpcRequest) returns (DiscoverClusterGrpcResponse) {} + + // ControllerRoutes + rpc getLeaderController(LeaderControllerGrpcRequest) returns (LeaderControllerGrpcResponse); + + // CreateStore + rpc createStore(CreateStoreGrpcRequest) returns (CreateStoreGrpcResponse) {} +} + +message EmptyPushGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + string pushJobId = 2; +} + +message EmptyPushGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + int32 version = 2; + int32 partitions = 3; + int32 replicas = 4; + string pubSubTopic = 5; + string pubSubBootstrapServers = 6; + bool enableSSL = 7; + CompressionStrategyGrpc compressionStrategy = 8; + string partitionerClass = 9; + map partitionerParams = 10; + bool daVinciPushStatusStoreEnabled = 11; + optional string pubSubSourceRegion = 12; +} + +message GetStoreGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +message GetStoreGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + StoreInfoGrpc storeInfo = 2; +} + +message ListStoresGrpcRequest { + string clusterName = 1; + optional bool includeSystemStores = 2; + optional string storeConfigNameFilter = 3; + optional string storeConfigValueFilter = 4; +} + +message ListStoresGrpcResponse { + string clusterName = 1; + repeated string storeName = 2; +} + +message ListBootstrappingVersionsGrpcRequest { + string clusterName = 1; +} + +message BootstrappingVersion { + string storeVersionName = 1; + string versionStatus = 2; +} + +message ListBootstrappingVersionsGrpcResponse { + string clusterName = 1; + repeated BootstrappingVersion bootstrappingVersions = 2; +} + +message DiscoverClusterGrpcRequest { + string storeName = 1; +} + +message DiscoverClusterGrpcResponse { + string clusterName = 1; + string storeName = 2; + string d2Service = 3; + string serverD2Service = 4; + string zkAddress = 5; + string pubSubBootstrapServers = 6; +} + +message AdminTopicMetadataGrpcRequest { + string clusterName = 1; + optional string storeName = 2; +} + +message AdminTopicMetadataGrpcResponse { + string clusterName = 1; + optional string storeName = 2; + int64 executionId = 3; + optional int64 offset = 4; + optional int64 upstreamOffset = 5; +} + +message UpdateAdminTopicMetadataGrpcRequest { + string clusterName = 1; + optional string storeName = 2; + int64 executionId = 3; + optional int64 offset = 4; + optional int64 upstreamOffset = 5; +} + +message UpdateAdminTopicMetadataGrpcResponse { + string clusterName = 1; + optional string storeName = 2; +} + +message AdminCommandExecutionStatusGrpcRequest { + string clusterName = 1; + int64 adminCommandExecutionId = 2; +} + +message AdminCommandExecutionStatusGrpcResponse { + string clusterName = 1; + int64 adminCommandExecutionId = 2; + string operation = 3; + string startTime = 4; + map fabricToExecutionStatusMap = 5; +} + +message LastSuccessfulAdminCommandExecutionGrpcRequest { + string clusterName = 1; +} + +message LastSuccessfulAdminCommandExecutionGrpcResponse { + string clusterName = 1; + int64 lastSuccessfulAdminCommandExecutionId = 2; +} + + +message ClusterStoreGrpcInfo { + string clusterName = 1; + string storeName = 2; +} + +message CreateStoreGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + string keySchema = 2; + string valueSchema = 3; + optional string owner = 4; + optional bool isSystemStore = 5; + optional string accessPermission = 6; +} + +message CreateStoreGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + string owner = 2; +} + +message UpdateAclForStoreGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + string accessPermissions = 3; +} + +message UpdateAclForStoreGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +message GetAclForStoreGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +message GetAclForStoreGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; + string accessPermissions = 2; +} + +message DeleteAclForStoreGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +message DeleteAclForStoreGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +message CheckResourceCleanupForStoreCreationGrpcRequest { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +message CheckResourceCleanupForStoreCreationGrpcResponse { + ClusterStoreGrpcInfo clusterStoreInfo = 1; +} + +enum ControllerGrpcErrorType { + UNKNOWN = 0; + INCORRECT_CONTROLLER = 1; + INVALID_SCHEMA = 2; + INVALID_CONFIG = 3; + STORE_NOT_FOUND = 4; + SCHEMA_NOT_FOUND = 5; + CONNECTION_ERROR = 6; + GENERAL_ERROR = 7; + BAD_REQUEST = 8; + CONCURRENT_BATCH_PUSH = 9; + RESOURCE_STILL_EXISTS = 10; +} + +message VeniceControllerGrpcErrorInfo { + uint32 statusCode = 1; + string errorMessage = 2; + optional ControllerGrpcErrorType errorType = 3; + optional string clusterName = 4; + optional string storeName = 5; +} + +message LeaderControllerGrpcRequest { + string clusterName = 1; // The cluster name +} + +message LeaderControllerGrpcResponse { + string clusterName = 1; // The cluster name + string httpUrl = 2; // Leader controller URL + string httpsUrl = 3; // SSL-enabled leader controller URL + string grpcUrl = 4; // gRPC URL for leader controller + string secureGrpcUrl = 5; // Secure gRPC URL for leader controller +} diff --git a/internal/venice-common/src/main/proto/VeniceEchoService.proto b/internal/venice-common/src/main/proto/VeniceEchoService.proto new file mode 100644 index 00000000000..fe8a592b3a7 --- /dev/null +++ b/internal/venice-common/src/main/proto/VeniceEchoService.proto @@ -0,0 +1,29 @@ +syntax = 'proto3'; +package com.linkedin.venice.protocols; + +import "google/rpc/status.proto"; +import "google/rpc/error_details.proto"; +import "google/protobuf/descriptor.proto"; + +import "VeniceStore.proto"; + +option java_multiple_files = true; + +// Define a custom option for method types +extend google.protobuf.MethodOptions { + string methodType = 50001; // Assign a unique field number +} + +service VeniceEchoService { + rpc echo (VeniceEchoRequest) returns (VeniceEchoResponse) { + option (methodType) = "GET"; // Assign the method type + } +} + +message VeniceEchoRequest { + string message = 1; +} + +message VeniceEchoResponse { + string message = 1; +} \ No newline at end of file diff --git a/internal/venice-common/src/main/proto/VeniceStore.proto b/internal/venice-common/src/main/proto/VeniceStore.proto new file mode 100644 index 00000000000..9ffeed4d592 --- /dev/null +++ b/internal/venice-common/src/main/proto/VeniceStore.proto @@ -0,0 +1,172 @@ +syntax = 'proto3'; +package com.linkedin.venice.protocols.controller; + +import "google/rpc/status.proto"; +import "google/rpc/error_details.proto"; +import "google/protobuf/timestamp.proto"; // Import for timestamp handling + +option java_multiple_files = true; + +enum DataReplicationPolicyGrpc { + NON_AGGREGATE = 0; + AGGREGATE = 1; + NONE = 2; + ACTIVE_ACTIVE = 3; +} + +enum BufferReplayPolicyGrpc { + REWIND_FROM_EOP = 0; + REWIND_FROM_SOP = 1; +} + +message HybridStoreConfigGrpc { + int64 rewindTimeInSeconds = 1; + int64 offsetLagThresholdToGoOnline = 2; + int64 producerTimestampLagThresholdToGoOnlineInSeconds = 3; + DataReplicationPolicyGrpc dataReplicationPolicy = 4; + BufferReplayPolicyGrpc bufferReplayPolicy = 5; +} + +enum CompressionStrategyGrpc { + NO_OP = 0; + GZIP = 1; + ZSTD = 2; + ZSTD_WITH_DICT = 3; +} + +enum BackupStrategyGrpc { + KEEP_MIN_VERSIONS = 0; + DELETE_ON_NEW_PUSH_START = 1; +} + +message ETLStoreConfigGrpc { + string etledUserProxyAccount = 1; + bool regularVersionETLEnabled = 2; + bool futureVersionETLEnabled = 3; +} + +message PartitionerConfigGrpc { + string partitionerClass = 1; + map partitionerParams = 2; +} + +message ViewConfigGrpc { + string viewClassName = 1; + map viewParameters = 2; +} + +message DataRecoveryVersionConfigGrpc { + string dataRecoverySourceFabric = 1; + bool dataRecoveryComplete = 2; + int32 dataRecoverySourceVersionNumber = 3; +} + +enum PushTypeGrpc { + BATCH = 0; + STREAM_REPROCESSING = 1; + STREAM = 2; + INCREMENTAL = 3; +} + +enum VersionStatusGrpc { + NOT_CREATED = 0; + STARTED = 1; + PUSHED = 2; // Version has been pushed to Venice but not ready for reads (writes disabled). + ONLINE = 3; // Version is pushed and ready to serve read requests. + ERROR = 4; // Version encountered an error. + CREATED = 5; // Version is created and persisted, but not fully prepared yet. + PARTIALLY_ONLINE = 6; // Version is online in some regions, failed in others (parent version only). + KILLED = 7; // This version has been killed. +} + +// Protobuf message for Store Info +message StoreInfoGrpc { + string name = 1; + string owner = 2; + int32 partitionCount = 3; + int32 currentVersion = 4; + int32 reservedVersion = 5; + int64 lowWatermark = 6; + bool enableStoreWrites = 7; + bool enableStoreReads = 8; + int64 storageQuotaInByte = 9; + bool hybridStoreOverheadBypass = 10; + int64 readQuotaInCU = 11; + bool accessControlled = 12; + bool chunkingEnabled = 13; + bool rmdChunkingEnabled = 14; + bool singleGetRouterCacheEnabled = 15; + bool batchGetRouterCacheEnabled = 16; + int32 batchGetLimit = 17; + int32 largestUsedVersionNumber = 18; + bool incrementalPushEnabled = 19; + bool clientDecompressionEnabled = 20; + int32 numVersionsToPreserve = 21; + bool migrating = 22; + bool writeComputationEnabled = 23; + int32 replicationMetadataVersionId = 24; + bool readComputationEnabled = 25; + int32 bootstrapToOnlineTimeoutInHours = 26; + bool nativeReplicationEnabled = 27; + string pushStreamSourceAddress = 28; + bool schemaAutoRegisterFromPushJobEnabled = 29; + bool superSetSchemaAutoGenerationForReadComputeEnabled = 30; + int32 latestSuperSetValueSchemaId = 31; + bool hybridStoreDiskQuotaEnabled = 32; + int64 backupVersionRetentionMs = 33; + int32 replicationFactor = 34; + bool migrationDuplicateStore = 35; + string nativeReplicationSourceFabric = 36; + bool storeMetadataSystemStoreEnabled = 37; + bool storeMetaSystemStoreEnabled = 38; + bool daVinciPushStatusStoreEnabled = 39; + bool activeActiveReplicationEnabled = 40; + string kafkaBrokerUrl = 41; + bool storageNodeReadQuotaEnabled = 42; + int64 minCompactionLagSeconds = 43; + int64 maxCompactionLagSeconds = 44; + int32 maxRecordSizeBytes = 45; + int32 maxNearlineRecordSizeBytes = 46; + bool unusedSchemaDeletionEnabled = 47; + bool blobTransferEnabled = 48; + + map coloToCurrentVersions = 49; + repeated VersionGrpc versions = 50; + optional HybridStoreConfigGrpc hybridStoreConfig = 51; + CompressionStrategyGrpc compressionStrategy = 52; + BackupStrategyGrpc backupStrategy = 53; + optional ETLStoreConfigGrpc etlStoreConfig = 54; + PartitionerConfigGrpc partitionerConfig = 55; + map viewConfigs = 56; +} + +message VersionGrpc { + string storeName = 1; + int32 number = 2; + int64 createdTime = 3; + VersionStatusGrpc status = 4; + string pushJobId = 5; + CompressionStrategyGrpc compressionStrategy = 6; + string pushStreamSourceAddress = 7; + bool chunkingEnabled = 8; + bool rmdChunkingEnabled = 9; + PushTypeGrpc pushType = 10; + int32 partitionCount = 11; + PartitionerConfigGrpc partitionerConfig = 12; + int32 replicationFactor = 13; + string nativeReplicationSourceFabric = 14; + bool incrementalPushEnabled = 15; + bool separateRealTimeTopicEnabled = 16; + bool blobTransferEnabled = 17; + bool useVersionLevelIncrementalPushEnabled = 18; + optional HybridStoreConfigGrpc hybridConfig = 19; + bool useVersionLevelHybridConfig = 20; + bool activeActiveReplicationEnabled = 21; + int32 timestampMetadataVersionId = 22; + optional DataRecoveryVersionConfigGrpc dataRecoveryConfig = 23; + bool deferVersionSwap = 24; + map viewConfigs = 25; + int32 repushSourceVersion = 26; + bool versionSwapDeferred = 27; +} + diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/ClusterDiscoveryRequestTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/ClusterDiscoveryRequestTest.java new file mode 100644 index 00000000000..d4658820c32 --- /dev/null +++ b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/ClusterDiscoveryRequestTest.java @@ -0,0 +1,24 @@ +package com.linkedin.venice.controllerapi.request; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +import org.testng.annotations.Test; + + +public class ClusterDiscoveryRequestTest { + @Test + public void testClusterDiscoveryRequest() { + // Case 1: Store name is provided + ClusterDiscoveryRequest clusterDiscoveryRequest = new ClusterDiscoveryRequest("storeName"); + assertNotNull(clusterDiscoveryRequest.getStoreName()); + assertEquals(clusterDiscoveryRequest.getStoreName(), "storeName"); + + // Case 2: Store name is not provided + IllegalArgumentException exception = + expectThrows(IllegalArgumentException.class, () -> new ClusterDiscoveryRequest(null)); + assertTrue(exception.getMessage().contains("The request is missing the store_name")); + } +} diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/ControllerRequestTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/ControllerRequestTest.java new file mode 100644 index 00000000000..7bf49dce5ad --- /dev/null +++ b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/ControllerRequestTest.java @@ -0,0 +1,64 @@ +package com.linkedin.venice.controllerapi.request; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.expectThrows; + +import org.testng.annotations.Test; + + +public class ControllerRequestTest { + private static final String CLUSTER = "cluster"; + private static final String STORE_NAME = "store_name"; + + @Test + public void testValidInputs() { + // Test with only cluster name + ControllerRequest request1 = new ControllerRequest("testCluster"); + assertEquals("testCluster", request1.getClusterName()); + assertNull(request1.getStoreName()); + + // Test with cluster name and store name + ControllerRequest request2 = new ControllerRequest("testCluster", "testStore"); + assertEquals("testCluster", request2.getClusterName()); + assertEquals("testStore", request2.getStoreName()); + } + + @Test + public void testInvalidClusterName() { + Exception exception1 = expectThrows(IllegalArgumentException.class, () -> new ControllerRequest(null)); + assertEquals("The request is missing the cluster_name, which is a mandatory field.", exception1.getMessage()); + + Exception exception2 = expectThrows(IllegalArgumentException.class, () -> new ControllerRequest("")); + assertEquals("The request is missing the cluster_name, which is a mandatory field.", exception2.getMessage()); + + Exception exception3 = expectThrows(IllegalArgumentException.class, () -> new ControllerRequest(null, "testStore")); + assertEquals("The request is missing the cluster_name, which is a mandatory field.", exception3.getMessage()); + + Exception exception4 = expectThrows(IllegalArgumentException.class, () -> new ControllerRequest("", "testStore")); + assertEquals("The request is missing the cluster_name, which is a mandatory field.", exception4.getMessage()); + } + + @Test + public void testInvalidStoreName() { + Exception exception1 = + expectThrows(IllegalArgumentException.class, () -> new ControllerRequest("testCluster", null)); + assertEquals("The request is missing the store_name, which is a mandatory field.", exception1.getMessage()); + + Exception exception2 = expectThrows(IllegalArgumentException.class, () -> new ControllerRequest("testCluster", "")); + assertEquals("The request is missing the store_name, which is a mandatory field.", exception2.getMessage()); + } + + @Test + public void testValidateParam() { + assertEquals("validParam", ControllerRequest.validateParam("validParam", CLUSTER)); + + Exception exception1 = + expectThrows(IllegalArgumentException.class, () -> ControllerRequest.validateParam(null, CLUSTER)); + assertEquals("The request is missing the cluster, which is a mandatory field.", exception1.getMessage()); + + Exception exception2 = + expectThrows(IllegalArgumentException.class, () -> ControllerRequest.validateParam("", STORE_NAME)); + assertEquals("The request is missing the store_name, which is a mandatory field.", exception2.getMessage()); + } +} diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/CreateNewStoreRequestTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/CreateNewStoreRequestTest.java new file mode 100644 index 00000000000..1be0faff14a --- /dev/null +++ b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/request/CreateNewStoreRequestTest.java @@ -0,0 +1,120 @@ +package com.linkedin.venice.controllerapi.request; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +import org.testng.annotations.Test; + + +public class CreateNewStoreRequestTest { + private static final String CLUSTER_NAME = "test-cluster"; + private static final String STORE_NAME = "storeName"; + private static final String OWNER = "owner"; + private static final String KEY_SCHEMA = "int"; + private static final String VALUE_SCHEMA = "string"; + private static final String ACCESS_PERMISSION = + "{\"owner\":\"owner\",\"readers\":[\"reader1\",\"reader2\"],\"writers\":[\"writer1\",\"writer2\"]}"; + + @Test + public void testValidInputs() { + boolean isSystemStore = false; + + CreateNewStoreRequest request = new CreateNewStoreRequest( + CLUSTER_NAME, + STORE_NAME, + OWNER, + KEY_SCHEMA, + VALUE_SCHEMA, + ACCESS_PERMISSION, + isSystemStore); + assertNotNull(request); + assertEquals(request.getClusterName(), CLUSTER_NAME); + assertEquals(request.getStoreName(), STORE_NAME); + assertEquals(request.getOwner(), OWNER); + assertEquals(request.getKeySchema(), KEY_SCHEMA); + assertEquals(request.getValueSchema(), VALUE_SCHEMA); + assertEquals(request.getAccessPermissions(), ACCESS_PERMISSION); + assertEquals(request.isSystemStore(), isSystemStore); + } + + @Test + public void testInvalidClusterName() { + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(null, STORE_NAME, OWNER, KEY_SCHEMA, VALUE_SCHEMA, null, false)); + assertEquals("The request is missing the cluster_name, which is a mandatory field.", exception1.getMessage()); + + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest("", STORE_NAME, OWNER, KEY_SCHEMA, VALUE_SCHEMA, null, false)); + assertEquals("The request is missing the cluster_name, which is a mandatory field.", exception2.getMessage()); + } + + @Test + public void testInvalidStoreName() { + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(CLUSTER_NAME, null, OWNER, KEY_SCHEMA, VALUE_SCHEMA, null, false)); + assertEquals("The request is missing the store_name, which is a mandatory field.", exception1.getMessage()); + + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(CLUSTER_NAME, "", OWNER, KEY_SCHEMA, VALUE_SCHEMA, null, false)); + assertEquals("The request is missing the store_name, which is a mandatory field.", exception2.getMessage()); + } + + @Test + public void testInvalidKeySchema() { + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, OWNER, null, VALUE_SCHEMA, null, false)); + assertEquals("The request is missing the Key schema, which is a mandatory field.", exception1.getMessage()); + + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, OWNER, "", VALUE_SCHEMA, null, false)); + assertEquals("The request is missing the Key schema, which is a mandatory field.", exception2.getMessage()); + } + + @Test + public void testInvalidValueSchema() { + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, OWNER, KEY_SCHEMA, null, null, false)); + assertEquals("The request is missing the Value schema, which is a mandatory field.", exception1.getMessage()); + + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, OWNER, KEY_SCHEMA, "", null, false)); + assertEquals("The request is missing the Value schema, which is a mandatory field.", exception2.getMessage()); + } + + @Test + public void testNullOwnerAndAccessPermissions() { + CreateNewStoreRequest request = + new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, null, KEY_SCHEMA, VALUE_SCHEMA, null, false); + + assertEquals(CLUSTER_NAME, request.getClusterName()); + assertEquals(STORE_NAME, request.getStoreName()); + assertEquals(request.getOwner(), CreateNewStoreRequest.DEFAULT_STORE_OWNER); + assertEquals(KEY_SCHEMA, request.getKeySchema()); + assertEquals(VALUE_SCHEMA, request.getValueSchema()); + assertNull(request.getAccessPermissions()); + assertFalse(request.isSystemStore()); + } + + @Test + public void testIsSystemStore() { + CreateNewStoreRequest request1 = + new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, OWNER, KEY_SCHEMA, VALUE_SCHEMA, null, true); + assertTrue(request1.isSystemStore()); + + CreateNewStoreRequest request2 = + new CreateNewStoreRequest(CLUSTER_NAME, STORE_NAME, OWNER, KEY_SCHEMA, VALUE_SCHEMA, null, false); + assertFalse(request2.isSystemStore()); + } +} diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/transport/GrpcRequestResponseConverterTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/transport/GrpcRequestResponseConverterTest.java new file mode 100644 index 00000000000..ebd47848593 --- /dev/null +++ b/internal/venice-common/src/test/java/com/linkedin/venice/controllerapi/transport/GrpcRequestResponseConverterTest.java @@ -0,0 +1,166 @@ +package com.linkedin.venice.controllerapi.transport; + +import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +import com.google.protobuf.Any; +import com.google.rpc.Code; +import com.google.rpc.Status; +import com.linkedin.venice.client.exceptions.VeniceClientException; +import com.linkedin.venice.controllerapi.ControllerResponse; +import com.linkedin.venice.controllerapi.request.ControllerRequest; +import com.linkedin.venice.protocols.controller.ClusterStoreGrpcInfo; +import com.linkedin.venice.protocols.controller.ControllerGrpcErrorType; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcErrorInfo; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import io.grpc.stub.StreamObserver; +import org.testng.annotations.Test; + + +public class GrpcRequestResponseConverterTest { + private static final String TEST_CLUSTER = "testCluster"; + private static final String TEST_STORE = "testStore"; + + @Test + public void testGetClusterStoreGrpcInfoFromResponse() { + // Test with all fields set + ControllerResponse response = mock(ControllerResponse.class); + when(response.getCluster()).thenReturn("testCluster"); + when(response.getName()).thenReturn("testStore"); + ClusterStoreGrpcInfo grpcInfo = GrpcRequestResponseConverter.getClusterStoreGrpcInfo(response); + assertEquals(grpcInfo.getClusterName(), "testCluster"); + assertEquals(grpcInfo.getStoreName(), "testStore"); + + // Test with null fields + when(response.getCluster()).thenReturn(null); + when(response.getName()).thenReturn(null); + grpcInfo = GrpcRequestResponseConverter.getClusterStoreGrpcInfo(response); + assertEquals(grpcInfo.getClusterName(), ""); + assertEquals(grpcInfo.getStoreName(), ""); + } + + @Test + public void testGetClusterStoreGrpcInfoFromRequest() { + // Test with all fields set + ControllerRequest request = mock(ControllerRequest.class); + when(request.getClusterName()).thenReturn("testCluster"); + when(request.getStoreName()).thenReturn("testStore"); + ClusterStoreGrpcInfo grpcInfo = GrpcRequestResponseConverter.getClusterStoreGrpcInfo(request); + assertEquals(grpcInfo.getClusterName(), "testCluster"); + assertEquals(grpcInfo.getStoreName(), "testStore"); + + // Test with null fields + when(request.getClusterName()).thenReturn(null); + when(request.getStoreName()).thenReturn(null); + grpcInfo = GrpcRequestResponseConverter.getClusterStoreGrpcInfo(request); + assertEquals(grpcInfo.getClusterName(), ""); + assertEquals(grpcInfo.getStoreName(), ""); + } + + @Test + public void testSendErrorResponse() { + StreamObserver responseObserver = mock(StreamObserver.class); + + Exception e = new Exception("Test error message"); + Code errorCode = Code.INVALID_ARGUMENT; + ControllerGrpcErrorType errorType = ControllerGrpcErrorType.BAD_REQUEST; + + GrpcRequestResponseConverter.sendErrorResponse( + io.grpc.Status.Code.INVALID_ARGUMENT, + ControllerGrpcErrorType.BAD_REQUEST, + e, + TEST_CLUSTER, + TEST_STORE, + responseObserver); + + verify(responseObserver, times(1)).onError(argThat(statusRuntimeException -> { + com.google.rpc.Status status = StatusProto.fromThrowable((StatusRuntimeException) statusRuntimeException); + + VeniceControllerGrpcErrorInfo errorInfo = null; + for (Any detail: status.getDetailsList()) { + if (detail.is(VeniceControllerGrpcErrorInfo.class)) { + try { + errorInfo = detail.unpack(VeniceControllerGrpcErrorInfo.class); + break; + } catch (Exception ignored) { + } + } + } + + assertNotNull(errorInfo); + assertEquals(errorInfo.getErrorType(), errorType); + assertEquals(errorInfo.getErrorMessage(), "Test error message"); + assertEquals(errorInfo.getClusterName(), "testCluster"); + assertEquals(errorInfo.getStoreName(), "testStore"); + assertEquals(status.getCode(), errorCode.getNumber()); + + return true; + })); + } + + @Test + public void testParseControllerGrpcError() { + // Create a valid VeniceControllerGrpcErrorInfo + VeniceControllerGrpcErrorInfo errorInfo = VeniceControllerGrpcErrorInfo.newBuilder() + .setErrorType(ControllerGrpcErrorType.BAD_REQUEST) + .setErrorMessage("Invalid input") + .setStatusCode(Code.INVALID_ARGUMENT.getNumber()) + .build(); + + // Wrap in a com.google.rpc.Status + Status rpcStatus = + Status.newBuilder().setCode(Code.INVALID_ARGUMENT.getNumber()).addDetails(Any.pack(errorInfo)).build(); + + // Convert to StatusRuntimeException + StatusRuntimeException exception = StatusProto.toStatusRuntimeException(rpcStatus); + + // Parse the error + VeniceControllerGrpcErrorInfo parsedError = GrpcRequestResponseConverter.parseControllerGrpcError(exception); + + // Assert the parsed error matches the original + assertEquals(parsedError.getErrorType(), ControllerGrpcErrorType.BAD_REQUEST); + assertEquals(parsedError.getErrorMessage(), "Invalid input"); + assertEquals(parsedError.getStatusCode(), Code.INVALID_ARGUMENT.getNumber()); + } + + @Test + public void testParseControllerGrpcErrorWithNoDetails() { + // Create an exception with no details + Status rpcStatus = Status.newBuilder().setCode(Code.UNKNOWN.getNumber()).build(); + StatusRuntimeException exception = StatusProto.toStatusRuntimeException(rpcStatus); + + VeniceClientException thrownException = expectThrows( + VeniceClientException.class, + () -> GrpcRequestResponseConverter.parseControllerGrpcError(exception)); + + assertEquals(thrownException.getMessage(), "An unknown gRPC error occurred. Error code: UNKNOWN"); + } + + @Test + public void testParseControllerGrpcErrorWithUnpackFailure() { + // Create a corrupted detail + Any corruptedDetail = Any.newBuilder() + .setTypeUrl("type.googleapis.com/" + VeniceControllerGrpcErrorInfo.getDescriptor().getFullName()) + .setValue(com.google.protobuf.ByteString.copyFromUtf8("corrupted data")) + .build(); + + Status rpcStatus = + Status.newBuilder().setCode(Code.INVALID_ARGUMENT.getNumber()).addDetails(corruptedDetail).build(); + + StatusRuntimeException exception = StatusProto.toStatusRuntimeException(rpcStatus); + + VeniceClientException thrownException = expectThrows(VeniceClientException.class, () -> { + GrpcRequestResponseConverter.parseControllerGrpcError(exception); + }); + + assertTrue(thrownException.getMessage().contains("Failed to unpack error details")); + } +} diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/grpc/GrpcUtilsTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/grpc/GrpcUtilsTest.java index d89a2d36751..6059d27e94d 100644 --- a/internal/venice-common/src/test/java/com/linkedin/venice/grpc/GrpcUtilsTest.java +++ b/internal/venice-common/src/test/java/com/linkedin/venice/grpc/GrpcUtilsTest.java @@ -1,14 +1,29 @@ package com.linkedin.venice.grpc; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; import com.linkedin.venice.acl.handler.AccessResult; +import com.linkedin.venice.client.exceptions.VeniceClientException; +import com.linkedin.venice.exceptions.VeniceException; import com.linkedin.venice.security.SSLFactory; import com.linkedin.venice.utils.SslUtils; +import io.grpc.Attributes; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ServerCall; import io.grpc.Status; +import io.grpc.TlsChannelCredentials; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; import javax.net.ssl.TrustManager; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; @@ -70,4 +85,124 @@ public void testHttpResponseStatusToGrpcStatus() { grpcStatus.getDescription(), "Mismatch in error description for the mapped grpc status"); } + + @Test + public void testBuildChannelCredentials() { + // Case 1: sslFactory is null, expect InsecureChannelCredentials + ChannelCredentials credentials = GrpcUtils.buildChannelCredentials(null); + assertTrue( + credentials instanceof InsecureChannelCredentials, + "Expected InsecureChannelCredentials when sslFactory is null"); + + // Case 2: Valid sslFactory, expect TlsChannelCredentials + SSLFactory validSslFactory = SslUtils.getVeniceLocalSslFactory(); + credentials = GrpcUtils.buildChannelCredentials(validSslFactory); + assertTrue( + credentials instanceof TlsChannelCredentials, + "Expected TlsChannelCredentials when sslFactory is provided"); + + // Case 3: SSLFactory throws an exception when initializing credentials + SSLFactory faultySslFactory = mock(SSLFactory.class); + Exception exception = + expectThrows(VeniceClientException.class, () -> GrpcUtils.buildChannelCredentials(faultySslFactory)); + assertEquals( + exception.getMessage(), + "Failed to initialize SSL channel credentials for Venice gRPC Transport Client"); + } + + @Test + public void testExtractGrpcClientCertWithValidCertificate() throws SSLPeerUnverifiedException { + // Mock SSLSession and Certificate + SSLSession sslSession = mock(SSLSession.class); + X509Certificate x509Certificate = mock(X509Certificate.class); + + // Mock the ServerCall and its attributes + Attributes attributes = Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(attributes); + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[] { x509Certificate }); + + // Extract the certificate + X509Certificate extractedCertificate = GrpcUtils.extractGrpcClientCert(call); + + // Verify the returned certificate + assertEquals(extractedCertificate, x509Certificate); + } + + @Test + public void testExtractGrpcClientCertWithNullSslSession() { + // Mock the ServerCall with null SSLSession + Attributes attributes = Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_SSL_SESSION, null).build(); + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(attributes); + when(call.getAuthority()).thenReturn("test-authority"); + + // Expect a VeniceException + VeniceException thrownException = expectThrows(VeniceException.class, () -> { + GrpcUtils.extractGrpcClientCert(call); + }); + + // Verify the exception message + assertEquals(thrownException.getMessage(), "Failed to obtain SSL session"); + } + + @Test + public void testExtractGrpcClientCertWithPeerCertificateNotX509() throws SSLPeerUnverifiedException { + // Mock SSLSession and Certificate + SSLSession sslSession = mock(SSLSession.class); + Certificate nonX509Certificate = mock(Certificate.class); + + // Mock the ServerCall and its attributes + Attributes attributes = Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(attributes); + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[] { nonX509Certificate }); + + // Expect IllegalArgumentException + IllegalArgumentException thrownException = + expectThrows(IllegalArgumentException.class, () -> GrpcUtils.extractGrpcClientCert(call)); + + // Verify the exception message + assertTrue( + thrownException.getMessage() + .contains("Only certificates of type java.security.cert.X509Certificate are supported")); + } + + @Test + public void testExtractGrpcClientCertWithNullPeerCertificates() throws SSLPeerUnverifiedException { + // Mock SSLSession with null peer certificates + SSLSession sslSession = mock(SSLSession.class); + when(sslSession.getPeerCertificates()).thenReturn(null); + + // Mock the ServerCall and its attributes + Attributes attributes = Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(attributes); + + // Expect NullPointerException or VeniceException + NullPointerException thrownException = + expectThrows(NullPointerException.class, () -> GrpcUtils.extractGrpcClientCert(call)); + + // Verify the exception is thrown + assertNotNull(thrownException); + } + + @Test + public void testExtractGrpcClientCertWithEmptyPeerCertificates() throws SSLPeerUnverifiedException { + // Mock SSLSession with empty peer certificates + SSLSession sslSession = mock(SSLSession.class); + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[] {}); + + // Mock the ServerCall and its attributes + Attributes attributes = Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(attributes); + + // Expect IndexOutOfBoundsException + IndexOutOfBoundsException thrownException = + expectThrows(IndexOutOfBoundsException.class, () -> GrpcUtils.extractGrpcClientCert(call)); + + // Verify the exception is thrown + assertNotNull(thrownException); + } } diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerConfigTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerConfigTest.java new file mode 100644 index 00000000000..82ccaa6ea0e --- /dev/null +++ b/internal/venice-common/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerConfigTest.java @@ -0,0 +1,112 @@ +package com.linkedin.venice.grpc; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +import com.linkedin.venice.security.SSLFactory; +import io.grpc.BindableService; +import io.grpc.ServerCredentials; +import io.grpc.ServerInterceptor; +import java.util.Collections; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import org.testng.annotations.Test; + + +public class VeniceGrpcServerConfigTest { + @Test + public void testBuilderWithAllFieldsSet() { + ServerCredentials credentials = mock(ServerCredentials.class); + BindableService service = mock(BindableService.class); + ServerInterceptor interceptor = mock(ServerInterceptor.class); + SSLFactory sslFactory = mock(SSLFactory.class); + Executor executor = Executors.newSingleThreadExecutor(); + + VeniceGrpcServerConfig config = new VeniceGrpcServerConfig.Builder().setPort(8080) + .setCredentials(credentials) + .setService(service) + .setInterceptor(interceptor) + .setSslFactory(sslFactory) + .setExecutor(executor) + .build(); + + assertEquals(config.getPort(), 8080); + assertEquals(config.getCredentials(), credentials); + assertEquals(config.getService(), service); + assertEquals(config.getInterceptors(), Collections.singletonList(interceptor)); + assertEquals(config.getSslFactory(), sslFactory); + assertEquals(config.getExecutor(), executor); + } + + @Test + public void testBuilderWithDefaultInterceptors() { + BindableService service = mock(BindableService.class); + + VeniceGrpcServerConfig config = + new VeniceGrpcServerConfig.Builder().setPort(8080).setService(service).setNumThreads(2).build(); + + assertTrue(config.getInterceptors().isEmpty()); + assertNotNull(config.getExecutor()); + } + + @Test + public void testBuilderWithDefaultExecutor() { + BindableService service = mock(BindableService.class); + + VeniceGrpcServerConfig config = + new VeniceGrpcServerConfig.Builder().setPort(8080).setService(service).setNumThreads(4).build(); + + assertNotNull(config.getExecutor()); + assertEquals(((ThreadPoolExecutor) config.getExecutor()).getCorePoolSize(), 4); + } + + @Test + public void testToStringMethod() { + BindableService service = mock(BindableService.class); + when(service.toString()).thenReturn("MockService"); + + VeniceGrpcServerConfig config = + new VeniceGrpcServerConfig.Builder().setPort(9090).setService(service).setNumThreads(2).build(); + + String expectedString = "VeniceGrpcServerConfig{port=9090, service=MockService}"; + assertEquals(config.toString(), expectedString); + } + + @Test + public void testBuilderValidationWithMissingPort() { + BindableService service = mock(BindableService.class); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> { + new VeniceGrpcServerConfig.Builder().setService(service).setNumThreads(2).build(); + }); + + assertEquals(exception.getMessage(), "Port value is required to create the gRPC server but was not provided."); + } + + @Test + public void testBuilderValidationWithMissingService() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new VeniceGrpcServerConfig.Builder().setPort(8080).setNumThreads(2).build()); + + assertEquals(exception.getMessage(), "A non-null gRPC service instance is required to create the server."); + } + + @Test + public void testBuilderValidationWithInvalidThreadsAndExecutor() { + BindableService service = mock(BindableService.class); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new VeniceGrpcServerConfig.Builder().setPort(8080).setService(service).build()); + + assertEquals( + exception.getMessage(), + "gRPC server creation requires a valid number of threads (numThreads > 0) or a non-null executor."); + } +} diff --git a/services/venice-server/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerTest.java similarity index 70% rename from services/venice-server/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerTest.java rename to internal/venice-common/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerTest.java index 47a6ef204b6..072620c9bef 100644 --- a/services/venice-server/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerTest.java +++ b/internal/venice-common/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerTest.java @@ -1,34 +1,35 @@ package com.linkedin.venice.grpc; -import static org.mockito.Mockito.mock; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; import com.linkedin.venice.exceptions.VeniceException; -import com.linkedin.venice.listener.grpc.VeniceReadServiceImpl; -import com.linkedin.venice.listener.grpc.handlers.VeniceServerGrpcRequestProcessor; +import com.linkedin.venice.protocols.VeniceEchoRequest; +import com.linkedin.venice.protocols.VeniceEchoResponse; +import com.linkedin.venice.protocols.VeniceEchoServiceGrpc; import com.linkedin.venice.security.SSLFactory; import com.linkedin.venice.utils.SslUtils; import com.linkedin.venice.utils.TestUtils; import io.grpc.InsecureServerCredentials; import io.grpc.ServerCredentials; import io.grpc.TlsServerCredentials; +import io.grpc.stub.StreamObserver; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; public class VeniceGrpcServerTest { + private static final int NUM_THREADS = 2; private VeniceGrpcServer grpcServer; private VeniceGrpcServerConfig.Builder serverConfig; - private VeniceServerGrpcRequestProcessor grpcRequestProcessor; @BeforeMethod void setUp() { - grpcRequestProcessor = mock(VeniceServerGrpcRequestProcessor.class); serverConfig = new VeniceGrpcServerConfig.Builder().setPort(TestUtils.getFreePort()) - .setNumThreads(10) - .setService(new VeniceReadServiceImpl(grpcRequestProcessor)); + .setNumThreads(NUM_THREADS) + .setService(new VeniceEchoServiceImpl()); } @Test @@ -36,10 +37,12 @@ void startServerSuccessfully() { grpcServer = new VeniceGrpcServer(serverConfig.build()); grpcServer.start(); + assertNotNull(grpcServer.getServer()); + assertTrue(grpcServer.isRunning()); assertFalse(grpcServer.isTerminated()); grpcServer.stop(); - assertTrue(grpcServer.isShutdown()); + assertFalse(grpcServer.isRunning()); } @Test @@ -65,7 +68,7 @@ void testServerShutdown() throws InterruptedException { Thread.sleep(500); grpcServer.stop(); - assertTrue(grpcServer.isShutdown()); + assertFalse(grpcServer.isRunning()); Thread.sleep(500); @@ -89,4 +92,13 @@ void testServerWithSSL() { assertFalse(serverCredentials instanceof TlsServerCredentials); assertTrue(serverCredentials instanceof InsecureServerCredentials); } + + public static class VeniceEchoServiceImpl extends VeniceEchoServiceGrpc.VeniceEchoServiceImplBase { + @Override + public void echo(VeniceEchoRequest grpcRequest, StreamObserver responseObserver) { + VeniceEchoResponse grpcResponse = VeniceEchoResponse.newBuilder().setMessage(grpcRequest.getMessage()).build(); + responseObserver.onNext(grpcResponse); + responseObserver.onCompleted(); + } + } } diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/meta/TestInstance.java b/internal/venice-common/src/test/java/com/linkedin/venice/meta/TestInstance.java index 47c27598430..6395054ec15 100644 --- a/internal/venice-common/src/test/java/com/linkedin/venice/meta/TestInstance.java +++ b/internal/venice-common/src/test/java/com/linkedin/venice/meta/TestInstance.java @@ -23,4 +23,17 @@ public void parsesNodeId() { Assert.assertEquals(host.getHost(), "localhost"); Assert.assertEquals(host.getPort(), 1234); } + + @Test + public void testInstanceWithGrpcAddress() { + Instance nonGrpcInstance = new Instance("localhost_1234", "localhost", 1234); + Assert.assertEquals(nonGrpcInstance.getGrpcSslPort(), -1); + Assert.assertEquals(nonGrpcInstance.getGrpcPort(), -1); + + Instance grpcInstance = new Instance("localhost_1234", "localhost", 1234, 1235, 1236); + Assert.assertEquals(grpcInstance.getGrpcPort(), 1235); + Assert.assertEquals(grpcInstance.getGrpcSslPort(), 1236); + Assert.assertEquals(grpcInstance.getGrpcUrl(), "localhost:1235"); + Assert.assertEquals(grpcInstance.getGrpcSslUrl(), "localhost:1236"); + } } diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/controller/server/TestAdminSparkWithMocks.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/controller/server/TestAdminSparkWithMocks.java index 925625ab4c5..166d393a08f 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/controller/server/TestAdminSparkWithMocks.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/controller/server/TestAdminSparkWithMocks.java @@ -6,9 +6,11 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ControllerRequestHandlerDependencies; import com.linkedin.venice.controller.ParentControllerRegionState; -import com.linkedin.venice.controller.VeniceHelixAdmin; import com.linkedin.venice.controllerapi.ControllerApiConstants; import com.linkedin.venice.controllerapi.ControllerRoute; import com.linkedin.venice.controllerapi.VersionCreationResponse; @@ -43,6 +45,7 @@ import org.apache.http.message.BasicNameValuePair; import org.mockito.Mockito; import org.testng.Assert; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -52,10 +55,20 @@ * verifying any state changes that would be triggered by the admin. */ public class TestAdminSparkWithMocks { + private VeniceControllerRequestHandler requestHandler; + private Admin admin; + + @BeforeMethod(alwaysRun = true) + public void setUp() { + admin = Mockito.mock(Admin.class); + ControllerRequestHandlerDependencies dependencies = mock(ControllerRequestHandlerDependencies.class); + doReturn(admin).when(dependencies).getAdmin(); + requestHandler = new VeniceControllerRequestHandler(dependencies); + } + @Test public void testGetRealTimeTopicUsesAdmin() throws Exception { // setup server with mock admin, note returns topic "store_rt" - VeniceHelixAdmin admin = Mockito.mock(VeniceHelixAdmin.class); Store mockStore = new ZKStore( "store", "owner", @@ -81,8 +94,11 @@ public void testGetRealTimeTopicUsesAdmin() throws Exception { doReturn("store_rt").when(admin).getRealTimeTopic(anyString(), any(Store.class)); // Add a banned route not relevant to the test just to make sure theres coverage for unbanned routes still be // accessible - AdminSparkServer server = - ServiceFactory.getMockAdminSparkServer(admin, "clustername", Arrays.asList(ControllerRoute.ADD_DERIVED_SCHEMA)); + AdminSparkServer server = ServiceFactory.getMockAdminSparkServer( + admin, + "clustername", + Arrays.asList(ControllerRoute.ADD_DERIVED_SCHEMA), + requestHandler); int port = server.getPort(); // build request @@ -117,7 +133,6 @@ public void testGetRealTimeTopicUsesAdmin() throws Exception { @Test public void testBannedRoutesAreRejected() throws Exception { // setup server with mock admin, note returns topic "store_rt" - VeniceHelixAdmin admin = Mockito.mock(VeniceHelixAdmin.class); Store mockStore = new ZKStore( "store", "owner", @@ -141,8 +156,8 @@ public void testBannedRoutesAreRejected() throws Exception { doReturn("kafka-bootstrap").when(admin).getKafkaBootstrapServers(anyBoolean()); doReturn("store_rt").when(admin).getRealTimeTopic(anyString(), anyString()); doReturn("store_rt").when(admin).getRealTimeTopic(anyString(), any(Store.class)); - AdminSparkServer server = - ServiceFactory.getMockAdminSparkServer(admin, "clustername", Arrays.asList(ControllerRoute.REQUEST_TOPIC)); + AdminSparkServer server = ServiceFactory + .getMockAdminSparkServer(admin, "clustername", Arrays.asList(ControllerRoute.REQUEST_TOPIC), requestHandler); int port = server.getPort(); // build request @@ -189,7 +204,6 @@ public void testAAIncrementalPushRTSourceRegion(boolean sourceGridFabricPresent, Optional optionalemergencySourceRegion = Optional.empty(); Optional optionalSourceGridSourceFabric = Optional.empty(); - VeniceHelixAdmin admin = Mockito.mock(VeniceHelixAdmin.class); Store mockStore = new ZKStore( storeName, "owner", @@ -265,8 +279,11 @@ public void testAAIncrementalPushRTSourceRegion(boolean sourceGridFabricPresent, // Add a banned route not relevant to the test just to make sure theres coverage for unbanned routes still be // accessible - AdminSparkServer server = - ServiceFactory.getMockAdminSparkServer(admin, "clustername", Arrays.asList(ControllerRoute.ADD_DERIVED_SCHEMA)); + AdminSparkServer server = ServiceFactory.getMockAdminSparkServer( + admin, + "clustername", + Arrays.asList(ControllerRoute.ADD_DERIVED_SCHEMA), + requestHandler); int port = server.getPort(); final HttpPost post = new HttpPost("http://localhost:" + port + ControllerRoute.REQUEST_TOPIC.getPath()); post.setEntity(new UrlEncodedFormEntity(params)); @@ -303,7 +320,6 @@ public void testAAIncrementalPushRTSourceRegion(boolean sourceGridFabricPresent, public void testSamzaReplicationPolicyMode(boolean samzaPolicy, boolean storePolicy, boolean aaEnabled) throws Exception { // setup server with mock admin, note returns topic "store_rt" - VeniceHelixAdmin admin = Mockito.mock(VeniceHelixAdmin.class); Store mockStore = new ZKStore( "store", "owner", @@ -344,8 +360,11 @@ public void testSamzaReplicationPolicyMode(boolean samzaPolicy, boolean storePol // Add a banned route not relevant to the test just to make sure theres coverage for unbanned routes still be // accessible - AdminSparkServer server = - ServiceFactory.getMockAdminSparkServer(admin, "clustername", Arrays.asList(ControllerRoute.ADD_DERIVED_SCHEMA)); + AdminSparkServer server = ServiceFactory.getMockAdminSparkServer( + admin, + "clustername", + Arrays.asList(ControllerRoute.ADD_DERIVED_SCHEMA), + requestHandler); int port = server.getPort(); // build request diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestControllerGrpcEndpoints.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestControllerGrpcEndpoints.java new file mode 100644 index 00000000000..f5d561dfcde --- /dev/null +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestControllerGrpcEndpoints.java @@ -0,0 +1,100 @@ +package com.linkedin.venice.endToEnd; + +import static com.linkedin.venice.ConfigKeys.CONTROLLER_GRPC_SERVER_ENABLED; +import static com.linkedin.venice.integration.utils.VeniceClusterWrapper.DEFAULT_KEY_SCHEMA; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +import com.linkedin.venice.controllerapi.StoreResponse; +import com.linkedin.venice.integration.utils.ServiceFactory; +import com.linkedin.venice.integration.utils.VeniceClusterCreateOptions; +import com.linkedin.venice.integration.utils.VeniceClusterWrapper; +import com.linkedin.venice.protocols.controller.ClusterStoreGrpcInfo; +import com.linkedin.venice.protocols.controller.CreateStoreGrpcRequest; +import com.linkedin.venice.protocols.controller.CreateStoreGrpcResponse; +import com.linkedin.venice.protocols.controller.DiscoverClusterGrpcRequest; +import com.linkedin.venice.protocols.controller.DiscoverClusterGrpcResponse; +import com.linkedin.venice.protocols.controller.LeaderControllerGrpcRequest; +import com.linkedin.venice.protocols.controller.LeaderControllerGrpcResponse; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc.VeniceControllerGrpcServiceBlockingStub; +import com.linkedin.venice.utils.TestUtils; +import com.linkedin.venice.utils.Utils; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import java.util.Properties; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + + +public class TestControllerGrpcEndpoints { + private VeniceClusterWrapper veniceCluster; + + @BeforeClass(alwaysRun = true) + public void setUp() { + Properties properties = new Properties(); + properties.put(CONTROLLER_GRPC_SERVER_ENABLED, true); + VeniceClusterCreateOptions options = new VeniceClusterCreateOptions.Builder().numberOfControllers(1) + .numberOfRouters(1) + .numberOfServers(1) + .extraProperties(properties) + .build(); + veniceCluster = ServiceFactory.getVeniceCluster(options); + } + + @AfterClass(alwaysRun = true) + public void tearDown() { + Utils.closeQuietlyWithErrorLogged(veniceCluster); + } + + @Test + public void testGrpcEndpointsWithGrpcClient() { + String storeName = Utils.getUniqueString("test_grpc_store"); + String controllerGrpcUrl = veniceCluster.getLeaderVeniceController().getControllerGrpcUrl(); + ManagedChannel channel = Grpc.newChannelBuilder(controllerGrpcUrl, InsecureChannelCredentials.create()).build(); + VeniceControllerGrpcServiceBlockingStub blockingStub = VeniceControllerGrpcServiceGrpc.newBlockingStub(channel); + + // Test 1: getLeaderController + LeaderControllerGrpcResponse grpcResponse = blockingStub.getLeaderController( + LeaderControllerGrpcRequest.newBuilder().setClusterName(veniceCluster.getClusterName()).build()); + assertEquals(grpcResponse.getHttpUrl(), veniceCluster.getLeaderVeniceController().getControllerUrl()); + assertEquals(grpcResponse.getGrpcUrl(), veniceCluster.getLeaderVeniceController().getControllerGrpcUrl()); + assertEquals( + grpcResponse.getSecureGrpcUrl(), + veniceCluster.getLeaderVeniceController().getControllerSecureGrpcUrl()); + + // Test 2: createStore + CreateStoreGrpcRequest createStoreGrpcRequest = CreateStoreGrpcRequest.newBuilder() + .setClusterStoreInfo( + ClusterStoreGrpcInfo.newBuilder() + .setClusterName(veniceCluster.getClusterName()) + .setStoreName(storeName) + .build()) + .setOwner("owner") + .setKeySchema(DEFAULT_KEY_SCHEMA) + .setValueSchema("\"string\"") + .build(); + + CreateStoreGrpcResponse response = blockingStub.createStore(createStoreGrpcRequest); + assertNotNull(response, "Response should not be null"); + assertNotNull(response.getClusterStoreInfo(), "ClusterStoreInfo should not be null"); + assertEquals(response.getClusterStoreInfo().getClusterName(), veniceCluster.getClusterName()); + assertEquals(response.getClusterStoreInfo().getStoreName(), storeName); + + veniceCluster.useControllerClient(controllerClient -> { + StoreResponse storeResponse = TestUtils.assertCommand(controllerClient.getStore(storeName)); + assertNotNull(storeResponse.getStore(), "Store should not be null"); + }); + + // Test 3: discover cluster + DiscoverClusterGrpcRequest discoverClusterGrpcRequest = + DiscoverClusterGrpcRequest.newBuilder().setStoreName(storeName).build(); + DiscoverClusterGrpcResponse discoverClusterGrpcResponse = + blockingStub.discoverClusterForStore(discoverClusterGrpcRequest); + assertNotNull(discoverClusterGrpcResponse, "Response should not be null"); + assertEquals(discoverClusterGrpcResponse.getStoreName(), storeName); + assertEquals(discoverClusterGrpcResponse.getClusterName(), veniceCluster.getClusterName()); + } +} diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ServiceFactory.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ServiceFactory.java index e412b1af2bb..f73a96ad066 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ServiceFactory.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ServiceFactory.java @@ -14,6 +14,7 @@ import com.linkedin.venice.client.store.ClientConfig; import com.linkedin.venice.controller.Admin; import com.linkedin.venice.controller.server.AdminSparkServer; +import com.linkedin.venice.controller.server.VeniceControllerRequestHandler; import com.linkedin.venice.controllerapi.ControllerRoute; import com.linkedin.venice.exceptions.VeniceException; import com.linkedin.venice.pubsub.PubSubClientsFactory; @@ -152,7 +153,8 @@ public static VeniceControllerWrapper getVeniceController(VeniceControllerCreate public static AdminSparkServer getMockAdminSparkServer( Admin admin, String cluster, - List bannedRoutes) { + List bannedRoutes, + VeniceControllerRequestHandler requestHandler) { return getService("MockAdminSparkServer", (serviceName) -> { Set clusters = new HashSet<>(); clusters.add(cluster); @@ -168,7 +170,8 @@ public static AdminSparkServer getMockAdminSparkServer( bannedRoutes, null, false, - new PubSubTopicRepository()); // Change this. + new PubSubTopicRepository(), + requestHandler); // Change this. server.start(); return server; }); diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceClusterWrapper.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceClusterWrapper.java index 8692a77799c..7a268d606d2 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceClusterWrapper.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceClusterWrapper.java @@ -536,6 +536,15 @@ public final synchronized String getAllControllersURLs() { .collect(Collectors.joining(",")); } + public final synchronized String getAllControllersGrpcURLs() { + return veniceControllerWrappers.isEmpty() + ? externalControllerDiscoveryURL + : veniceControllerWrappers.values() + .stream() + .map(VeniceControllerWrapper::getControllerGrpcUrl) + .collect(Collectors.joining(",")); + } + public VeniceControllerWrapper getLeaderVeniceController() { return getLeaderVeniceController(60 * Time.MS_PER_SECOND); } diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceControllerWrapper.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceControllerWrapper.java index 78e51df38b7..61043bb205f 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceControllerWrapper.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/VeniceControllerWrapper.java @@ -15,6 +15,8 @@ import static com.linkedin.venice.ConfigKeys.CLUSTER_TO_SERVER_D2; import static com.linkedin.venice.ConfigKeys.CONCURRENT_INIT_ROUTINES_ENABLED; import static com.linkedin.venice.ConfigKeys.CONTROLLER_ADD_VERSION_VIA_ADMIN_PROTOCOL; +import static com.linkedin.venice.ConfigKeys.CONTROLLER_ADMIN_GRPC_PORT; +import static com.linkedin.venice.ConfigKeys.CONTROLLER_ADMIN_SECURE_GRPC_PORT; import static com.linkedin.venice.ConfigKeys.CONTROLLER_NAME; import static com.linkedin.venice.ConfigKeys.CONTROLLER_PARENT_MODE; import static com.linkedin.venice.ConfigKeys.CONTROLLER_SSL_ENABLED; @@ -109,6 +111,8 @@ public class VeniceControllerWrapper extends ProcessWrapper { private final boolean isParent; private final int port; private final int securePort; + private final int adminGrpcPort; + private final int adminSecureGrpcPort; private final String zkAddress; private final List d2ServerList; private final MetricsRepository metricsRepository; @@ -121,6 +125,8 @@ private VeniceControllerWrapper( VeniceController service, int port, int securePort, + int adminGrpcPort, + int adminSecureGrpcPort, List configs, boolean isParent, List d2ServerList, @@ -132,6 +138,8 @@ private VeniceControllerWrapper( this.isParent = isParent; this.port = port; this.securePort = securePort; + this.adminGrpcPort = adminGrpcPort; + this.adminSecureGrpcPort = adminSecureGrpcPort; this.zkAddress = zkAddress; this.d2ServerList = d2ServerList; this.metricsRepository = metricsRepository; @@ -142,6 +150,8 @@ static StatefulServiceProvider generateService(VeniceCo return (serviceName, dataDirectory) -> { int adminPort = TestUtils.getFreePort(); int adminSecurePort = TestUtils.getFreePort(); + int adminGrpcPort = TestUtils.getFreePort(); + int adminSecureGrpcPort = TestUtils.getFreePort(); List propertiesList = new ArrayList<>(); VeniceProperties extraProps = new VeniceProperties(options.getExtraProperties()); @@ -182,6 +192,8 @@ static StatefulServiceProvider generateService(VeniceCo .put(DEFAULT_REPLICA_FACTOR, options.getReplicationFactor()) .put(ADMIN_PORT, adminPort) .put(ADMIN_SECURE_PORT, adminSecurePort) + .put(CONTROLLER_ADMIN_GRPC_PORT, adminGrpcPort) + .put(CONTROLLER_ADMIN_SECURE_GRPC_PORT, adminSecureGrpcPort) .put(DEFAULT_PARTITION_SIZE, options.getPartitionSize()) .put(DEFAULT_NUMBER_OF_PARTITION, options.getNumberOfPartitions()) .put(DEFAULT_MAX_NUMBER_OF_PARTITIONS, options.getMaxNumberOfPartitions()) @@ -375,6 +387,8 @@ static StatefulServiceProvider generateService(VeniceCo veniceController, adminPort, adminSecurePort, + adminGrpcPort, + adminSecureGrpcPort, propertiesList, options.isParent(), d2ServerList, @@ -401,6 +415,22 @@ public int getSecurePort() { return securePort; } + public int getAdminGrpcPort() { + return adminGrpcPort; + } + + public int getAdminSecureGrpcPort() { + return adminSecureGrpcPort; + } + + public String getControllerGrpcUrl() { + return getHost() + ":" + getAdminGrpcPort(); + } + + public String getControllerSecureGrpcUrl() { + return getHost() + ":" + getAdminSecureGrpcPort(); + } + public String getControllerUrl() { return "http://" + getHost() + ":" + getPort(); } diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ZkServerWrapper.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ZkServerWrapper.java index 312c9295977..3e422aa3ded 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ZkServerWrapper.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/integration/utils/ZkServerWrapper.java @@ -49,7 +49,7 @@ public class ZkServerWrapper extends ProcessWrapper { * The tick time can be low because this Zookeeper instance is intended to be used locally. */ private static final int TICK_TIME = 200; - private static final int MAX_SESSION_TIMEOUT = 10 * Time.MS_PER_SECOND; + private static final int MAX_SESSION_TIMEOUT = 6000 * Time.MS_PER_SECOND; private static final int NUM_CONNECTIONS = 5000; private static final String CLIENT_PORT_PROP = "clientPort"; diff --git a/internal/venice-test-common/src/main/java/com/linkedin/venice/utils/TestUtils.java b/internal/venice-test-common/src/main/java/com/linkedin/venice/utils/TestUtils.java index 48e3d3576c3..5717059c060 100644 --- a/internal/venice-test-common/src/main/java/com/linkedin/venice/utils/TestUtils.java +++ b/internal/venice-test-common/src/main/java/com/linkedin/venice/utils/TestUtils.java @@ -562,6 +562,8 @@ public static Properties getPropertiesForControllerConfig() { properties.put(ConfigKeys.DEFAULT_NUMBER_OF_PARTITION, "1"); properties.put(ConfigKeys.ADMIN_PORT, TestUtils.getFreePort()); properties.put(ConfigKeys.ADMIN_SECURE_PORT, TestUtils.getFreePort()); + properties.put(ConfigKeys.CONTROLLER_ADMIN_GRPC_PORT, TestUtils.getFreePort()); + properties.put(ConfigKeys.CONTROLLER_ADMIN_SECURE_GRPC_PORT, TestUtils.getFreePort()); return properties; } diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/ControllerRequestHandlerDependencies.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/ControllerRequestHandlerDependencies.java new file mode 100644 index 00000000000..43f4ca497b0 --- /dev/null +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/ControllerRequestHandlerDependencies.java @@ -0,0 +1,186 @@ +package com.linkedin.venice.controller; + +import com.linkedin.venice.SSLConfig; +import com.linkedin.venice.acl.DynamicAccessController; +import com.linkedin.venice.controllerapi.ControllerRoute; +import com.linkedin.venice.pubsub.PubSubTopicRepository; +import com.linkedin.venice.utils.VeniceProperties; +import io.tehuti.metrics.MetricsRepository; +import java.util.Collections; +import java.util.List; +import java.util.Set; + + +/** + * Dependencies for VeniceControllerRequestHandler + */ +public class ControllerRequestHandlerDependencies { + private final Admin admin; + private final boolean enforceSSL; + private final boolean sslEnabled; + private final boolean checkReadMethodForKafka; + private final SSLConfig sslConfig; + private final DynamicAccessController accessController; + private final List disabledRoutes; + private final Set clusters; + private final boolean disableParentRequestTopicForStreamPushes; + private final PubSubTopicRepository pubSubTopicRepository; + private final MetricsRepository metricsRepository; + private final VeniceProperties veniceProperties; + + private ControllerRequestHandlerDependencies(Builder builder) { + this.admin = builder.admin; + this.enforceSSL = builder.enforceSSL; + this.sslEnabled = builder.sslEnabled; + this.checkReadMethodForKafka = builder.checkReadMethodForKafka; + this.sslConfig = builder.sslConfig; + this.accessController = builder.accessController; + this.disabledRoutes = builder.disabledRoutes; + this.clusters = builder.clusters; + this.disableParentRequestTopicForStreamPushes = builder.disableParentRequestTopicForStreamPushes; + this.pubSubTopicRepository = builder.pubSubTopicRepository; + this.metricsRepository = builder.metricsRepository; + this.veniceProperties = builder.veniceProperties; + } + + public Admin getAdmin() { + return admin; + } + + public Set getClusters() { + return clusters; + } + + public boolean isEnforceSSL() { + return enforceSSL; + } + + public boolean isSslEnabled() { + return sslEnabled; + } + + public boolean isCheckReadMethodForKafka() { + return checkReadMethodForKafka; + } + + public SSLConfig getSslConfig() { + return sslConfig; + } + + public DynamicAccessController getAccessController() { + return accessController; + } + + public List getDisabledRoutes() { + return disabledRoutes; + } + + public boolean isDisableParentRequestTopicForStreamPushes() { + return disableParentRequestTopicForStreamPushes; + } + + public PubSubTopicRepository getPubSubTopicRepository() { + return pubSubTopicRepository; + } + + public MetricsRepository getMetricsRepository() { + return metricsRepository; + } + + public VeniceProperties getVeniceProperties() { + return veniceProperties; + } + + // Builder class for VeniceControllerRequestHandlerDependencies + public static class Builder { + private Admin admin; + private boolean enforceSSL; + private boolean sslEnabled; + private boolean checkReadMethodForKafka; + private SSLConfig sslConfig; + private DynamicAccessController accessController; + private List disabledRoutes; + private Set clusters; + private boolean disableParentRequestTopicForStreamPushes; + private PubSubTopicRepository pubSubTopicRepository; + private MetricsRepository metricsRepository; + private VeniceProperties veniceProperties; + + public Builder setAdmin(Admin admin) { + this.admin = admin; + return this; + } + + public Builder setClusters(Set clusters) { + this.clusters = clusters; + return this; + } + + public Builder setEnforceSSL(boolean enforceSSL) { + this.enforceSSL = enforceSSL; + return this; + } + + public Builder setSslEnabled(boolean sslEnabled) { + this.sslEnabled = sslEnabled; + return this; + } + + public Builder setCheckReadMethodForKafka(boolean checkReadMethodForKafka) { + this.checkReadMethodForKafka = checkReadMethodForKafka; + return this; + } + + public Builder setSslConfig(SSLConfig sslConfig) { + this.sslConfig = sslConfig; + return this; + } + + public Builder setAccessController(DynamicAccessController accessController) { + this.accessController = accessController; + return this; + } + + public Builder setDisabledRoutes(List disabledRoutes) { + this.disabledRoutes = disabledRoutes; + return this; + } + + public Builder setDisableParentRequestTopicForStreamPushes(boolean disableParentRequestTopicForStreamPushes) { + this.disableParentRequestTopicForStreamPushes = disableParentRequestTopicForStreamPushes; + return this; + } + + public Builder setPubSubTopicRepository(PubSubTopicRepository pubSubTopicRepository) { + this.pubSubTopicRepository = pubSubTopicRepository; + return this; + } + + public Builder setMetricsRepository(MetricsRepository metricsRepository) { + this.metricsRepository = metricsRepository; + return this; + } + + public Builder setVeniceProperties(VeniceProperties veniceProperties) { + this.veniceProperties = veniceProperties; + return this; + } + + private void verifyAndAddDefaults() { + if (admin == null) { + throw new IllegalArgumentException("admin is mandatory dependencies for VeniceControllerRequestHandler"); + } + if (pubSubTopicRepository == null) { + pubSubTopicRepository = new PubSubTopicRepository(); + } + if (disabledRoutes == null) { + disabledRoutes = Collections.emptyList(); + } + } + + public ControllerRequestHandlerDependencies build() { + verifyAndAddDefaults(); + return new ControllerRequestHandlerDependencies(this); + } + } +} diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceController.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceController.java index 2d07dd400bc..b9a6621e787 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceController.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceController.java @@ -11,13 +11,20 @@ import com.linkedin.venice.controller.kafka.TopicCleanupService; import com.linkedin.venice.controller.kafka.TopicCleanupServiceForParentController; import com.linkedin.venice.controller.server.AdminSparkServer; +import com.linkedin.venice.controller.server.VeniceControllerGrpcServiceImpl; +import com.linkedin.venice.controller.server.VeniceControllerRequestHandler; +import com.linkedin.venice.controller.server.grpc.ControllerSslSessionInterceptor; +import com.linkedin.venice.controller.server.grpc.ParentControllerRegionValidationInterceptor; import com.linkedin.venice.controller.stats.TopicCleanupServiceStats; import com.linkedin.venice.controller.supersetschema.SupersetSchemaGenerator; import com.linkedin.venice.controller.systemstore.SystemStoreRepairService; import com.linkedin.venice.d2.D2ClientFactory; import com.linkedin.venice.exceptions.VeniceException; +import com.linkedin.venice.grpc.VeniceGrpcServer; +import com.linkedin.venice.grpc.VeniceGrpcServerConfig; import com.linkedin.venice.pubsub.PubSubClientsFactory; import com.linkedin.venice.pubsub.PubSubTopicRepository; +import com.linkedin.venice.security.SSLFactory; import com.linkedin.venice.serialization.avro.AvroProtocolDefinition; import com.linkedin.venice.service.AbstractVeniceService; import com.linkedin.venice.service.ICProvider; @@ -25,13 +32,18 @@ import com.linkedin.venice.servicediscovery.ServiceDiscoveryAnnouncer; import com.linkedin.venice.system.store.ControllerClientBackedSystemSchemaInitializer; import com.linkedin.venice.utils.PropertyBuilder; +import com.linkedin.venice.utils.SslUtils; import com.linkedin.venice.utils.Utils; import com.linkedin.venice.utils.VeniceProperties; +import com.linkedin.venice.utils.concurrent.BlockingQueueType; +import com.linkedin.venice.utils.concurrent.ThreadPoolFactory; +import io.grpc.ServerInterceptor; import io.tehuti.metrics.MetricsRepository; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.ThreadPoolExecutor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -46,6 +58,8 @@ public class VeniceController { private VeniceControllerService controllerService; private AdminSparkServer adminServer; private AdminSparkServer secureAdminServer; + private VeniceGrpcServer adminGrpcServer; + private VeniceGrpcServer adminSecureGrpcServer; private TopicCleanupService topicCleanupService; private Optional storeBackupVersionCleanupService; @@ -69,6 +83,7 @@ public class VeniceController { private final PubSubTopicRepository pubSubTopicRepository = new PubSubTopicRepository(); private final PubSubClientsFactory pubSubClientsFactory; static final String CONTROLLER_SERVICE_NAME = "venice-controller"; + private ThreadPoolExecutor grpcExecutor = null; /** * Allocates a new {@code VeniceController} object. @@ -149,7 +164,6 @@ private void createServices() { externalSupersetSchemaGenerator, pubSubTopicRepository, pubSubClientsFactory); - adminServer = new AdminSparkServer( // no need to pass the hostname, we are binding to all the addresses multiClusterConfigs.getAdminPort(), @@ -164,7 +178,8 @@ private void createServices() { multiClusterConfigs.getCommonConfig().getJettyConfigOverrides(), // TODO: Builder pattern or just pass the config object here? multiClusterConfigs.getCommonConfig().isDisableParentRequestTopicForStreamPushes(), - pubSubTopicRepository); + pubSubTopicRepository, + new VeniceControllerRequestHandler(buildRequestHandlerDependencies(false))); if (sslEnabled) { /** * SSL enabled AdminSparkServer uses a different port number than the regular service. @@ -181,7 +196,8 @@ private void createServices() { multiClusterConfigs.getDisabledRoutes(), multiClusterConfigs.getCommonConfig().getJettyConfigOverrides(), multiClusterConfigs.getCommonConfig().isDisableParentRequestTopicForStreamPushes(), - pubSubTopicRepository); + pubSubTopicRepository, + new VeniceControllerRequestHandler(buildRequestHandlerDependencies(true))); } storeBackupVersionCleanupService = Optional.empty(); storeGraveyardCleanupService = Optional.empty(); @@ -231,6 +247,68 @@ private void createServices() { } // Run before enabling controller in helix so leadership won't hand back to this controller during schema requests. initializeSystemSchema(controllerService.getVeniceHelixAdmin()); + + // if gRpc server is not enabled, return early + if (multiClusterConfigs.isGrpcServerEnabled()) { + LOGGER.info("gRPC server is enabled in controller. Initializing gRPC server..."); + initializeGrpcServer(); + } + } + + private void initializeGrpcServer() { + ParentControllerRegionValidationInterceptor parentControllerRegionValidationInterceptor = + new ParentControllerRegionValidationInterceptor(controllerService.getVeniceHelixAdmin()); + List interceptors = new ArrayList<>(2); + interceptors.add(parentControllerRegionValidationInterceptor); + + VeniceControllerGrpcServiceImpl grpcService = + new VeniceControllerGrpcServiceImpl(new VeniceControllerRequestHandler(buildRequestHandlerDependencies(false))); + + grpcExecutor = ThreadPoolFactory.createThreadPool( + multiClusterConfigs.getGrpcServerThreadCount(), + "ControllerGrpcServer", + Integer.MAX_VALUE, + BlockingQueueType.LINKED_BLOCKING_QUEUE); + + adminGrpcServer = new VeniceGrpcServer( + new VeniceGrpcServerConfig.Builder().setPort(multiClusterConfigs.getAdminGrpcPort()) + .setService(grpcService) + .setExecutor(grpcExecutor) + .setInterceptors(interceptors) + .build()); + + if (sslEnabled) { + interceptors.add(new ControllerSslSessionInterceptor()); + SSLFactory sslFactory = SslUtils.getSSLFactory( + multiClusterConfigs.getSslConfig().get().getSslProperties(), + multiClusterConfigs.getSslFactoryClassName()); + VeniceControllerGrpcServiceImpl secureGrpcService = new VeniceControllerGrpcServiceImpl( + new VeniceControllerRequestHandler(buildRequestHandlerDependencies(true))); + adminSecureGrpcServer = new VeniceGrpcServer( + new VeniceGrpcServerConfig.Builder().setPort(multiClusterConfigs.getAdminSecureGrpcPort()) + .setService(secureGrpcService) + .setExecutor(grpcExecutor) + .setSslFactory(sslFactory) + .setInterceptors(interceptors) + .build()); + } + } + + private ControllerRequestHandlerDependencies buildRequestHandlerDependencies(boolean secure) { + ControllerRequestHandlerDependencies.Builder builder = + new ControllerRequestHandlerDependencies.Builder().setAdmin(controllerService.getVeniceHelixAdmin()) + .setMetricsRepository(metricsRepository) + .setClusters(multiClusterConfigs.getClusters()) + .setDisabledRoutes(multiClusterConfigs.getDisabledRoutes()) + .setVeniceProperties(multiClusterConfigs.getCommonConfig().getJettyConfigOverrides()) + .setDisableParentRequestTopicForStreamPushes( + multiClusterConfigs.getCommonConfig().isDisableParentRequestTopicForStreamPushes()) + .setPubSubTopicRepository(pubSubTopicRepository) + .setSslConfig(secure ? multiClusterConfigs.getSslConfig().orElse(null) : null) + .setCheckReadMethodForKafka(secure && multiClusterConfigs.adminCheckReadMethodForKafka()) + .setAccessController(secure ? accessController.orElse(null) : null) + .setEnforceSSL(secure || multiClusterConfigs.isControllerEnforceSSLOnly()); + return builder.build(); } /** @@ -255,6 +333,12 @@ public void start() { disabledPartitionEnablerService.ifPresent(AbstractVeniceService::start); // register with service discovery at the end asyncRetryingServiceDiscoveryAnnouncer.register(); + if (adminGrpcServer != null) { + adminGrpcServer.start(); + } + if (adminSecureGrpcServer != null) { + adminSecureGrpcServer.start(); + } LOGGER.info("Controller is started."); } @@ -311,6 +395,16 @@ public void stop() { unusedValueSchemaCleanupService.ifPresent(Utils::closeQuietlyWithErrorLogged); storeBackupVersionCleanupService.ifPresent(Utils::closeQuietlyWithErrorLogged); disabledPartitionEnablerService.ifPresent(Utils::closeQuietlyWithErrorLogged); + if (adminGrpcServer != null) { + adminGrpcServer.stop(); + } + if (adminSecureGrpcServer != null) { + adminSecureGrpcServer.stop(); + } + if (grpcExecutor != null) { + LOGGER.info("Shutting down gRPC executor"); + grpcExecutor.shutdown(); + } Utils.closeQuietlyWithErrorLogged(topicCleanupService); Utils.closeQuietlyWithErrorLogged(secureAdminServer); Utils.closeQuietlyWithErrorLogged(adminServer); diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerClusterConfig.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerClusterConfig.java index 827c9024e37..d7514d7c413 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerClusterConfig.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerClusterConfig.java @@ -26,6 +26,8 @@ import static com.linkedin.venice.ConfigKeys.CLUSTER_TO_D2; import static com.linkedin.venice.ConfigKeys.CLUSTER_TO_SERVER_D2; import static com.linkedin.venice.ConfigKeys.CONCURRENT_INIT_ROUTINES_ENABLED; +import static com.linkedin.venice.ConfigKeys.CONTROLLER_ADMIN_GRPC_PORT; +import static com.linkedin.venice.ConfigKeys.CONTROLLER_ADMIN_SECURE_GRPC_PORT; import static com.linkedin.venice.ConfigKeys.CONTROLLER_AUTO_MATERIALIZE_DAVINCI_PUSH_STATUS_SYSTEM_STORE; import static com.linkedin.venice.ConfigKeys.CONTROLLER_AUTO_MATERIALIZE_META_SYSTEM_STORE; import static com.linkedin.venice.ConfigKeys.CONTROLLER_BACKUP_VERSION_DEFAULT_RETENTION_MS; @@ -47,6 +49,8 @@ import static com.linkedin.venice.ConfigKeys.CONTROLLER_EARLY_DELETE_BACKUP_ENABLED; import static com.linkedin.venice.ConfigKeys.CONTROLLER_ENABLE_DISABLED_REPLICA_ENABLED; import static com.linkedin.venice.ConfigKeys.CONTROLLER_ENFORCE_SSL; +import static com.linkedin.venice.ConfigKeys.CONTROLLER_GRPC_SERVER_ENABLED; +import static com.linkedin.venice.ConfigKeys.CONTROLLER_GRPC_SERVER_THREAD_COUNT; import static com.linkedin.venice.ConfigKeys.CONTROLLER_HAAS_SUPER_CLUSTER_NAME; import static com.linkedin.venice.ConfigKeys.CONTROLLER_HELIX_CLOUD_ID; import static com.linkedin.venice.ConfigKeys.CONTROLLER_HELIX_CLOUD_INFO_PROCESSOR_NAME; @@ -233,6 +237,8 @@ public class VeniceControllerClusterConfig { private final String adminHostname; private final int adminPort; private final int adminSecurePort; + private final int adminGrpcPort; + private final int adminSecureGrpcPort; private final int controllerClusterReplica; // Name of the Helix cluster for controllers private final String controllerClusterName; @@ -290,6 +296,8 @@ public class VeniceControllerClusterConfig { private final boolean backupVersionRetentionBasedCleanupEnabled; private final boolean backupVersionMetadataFetchBasedCleanupEnabled; + private final boolean grpcServerEnabled; + private final int grpcServerThreadCount; private final boolean enforceSSLOnly; private final long terminalStateTopicCheckerDelayMs; private final List disabledRoutes; @@ -657,6 +665,12 @@ public VeniceControllerClusterConfig(VeniceProperties props) { this.adminPort = props.getInt(ADMIN_PORT); this.adminHostname = props.getString(ADMIN_HOSTNAME, Utils::getHostName); this.adminSecurePort = props.getInt(ADMIN_SECURE_PORT); + this.adminGrpcPort = props.getInt(CONTROLLER_ADMIN_GRPC_PORT, -1); + this.adminSecureGrpcPort = props.getInt(CONTROLLER_ADMIN_SECURE_GRPC_PORT, -1); + this.grpcServerEnabled = props.getBoolean(CONTROLLER_GRPC_SERVER_ENABLED, false); + this.grpcServerThreadCount = + props.getInt(CONTROLLER_GRPC_SERVER_THREAD_COUNT, Runtime.getRuntime().availableProcessors()); + /** * Override the config to false if the "Read" method check is not working as expected. */ @@ -854,8 +868,8 @@ public VeniceControllerClusterConfig(VeniceProperties props) { props.getBoolean(CONTROLLER_BACKUP_VERSION_RETENTION_BASED_CLEANUP_ENABLED, false); this.backupVersionMetadataFetchBasedCleanupEnabled = props.getBoolean(CONTROLLER_BACKUP_VERSION_METADATA_FETCH_BASED_CLEANUP_ENABLED, false); - this.enforceSSLOnly = props.getBoolean(CONTROLLER_ENFORCE_SSL, false); // By default, allow both secure and insecure - // routes + // By default, allow both secure and insecure routes + this.enforceSSLOnly = props.getBoolean(CONTROLLER_ENFORCE_SSL, false); this.terminalStateTopicCheckerDelayMs = props.getLong(TERMINAL_STATE_TOPIC_CHECK_DELAY_MS, TimeUnit.MINUTES.toMillis(10)); this.disableParentTopicTruncationUponCompletion = @@ -1226,6 +1240,14 @@ public int getAdminSecurePort() { return adminSecurePort; } + public int getAdminGrpcPort() { + return adminGrpcPort; + } + + public int getAdminSecureGrpcPort() { + return adminSecureGrpcPort; + } + public boolean adminCheckReadMethodForKafka() { return adminCheckReadMethodForKafka; } @@ -1472,6 +1494,14 @@ public boolean isControllerEnforceSSLOnly() { return enforceSSLOnly; } + public boolean isGrpcServerEnabled() { + return grpcServerEnabled; + } + + public int getGrpcServerThreadCount() { + return grpcServerThreadCount; + } + public long getTerminalStateTopicCheckerDelayMs() { return terminalStateTopicCheckerDelayMs; } diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerMultiClusterConfig.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerMultiClusterConfig.java index 2f9ea6d8a5b..0de71246bae 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerMultiClusterConfig.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceControllerMultiClusterConfig.java @@ -55,6 +55,14 @@ public int getAdminSecurePort() { return getCommonConfig().getAdminSecurePort(); } + public int getAdminGrpcPort() { + return getCommonConfig().getAdminGrpcPort(); + } + + public int getAdminSecureGrpcPort() { + return getCommonConfig().getAdminSecureGrpcPort(); + } + public boolean adminCheckReadMethodForKafka() { return getCommonConfig().adminCheckReadMethodForKafka(); } @@ -219,6 +227,14 @@ public boolean isControllerEnforceSSLOnly() { return getCommonConfig().isControllerEnforceSSLOnly(); } + public boolean isGrpcServerEnabled() { + return getCommonConfig().isGrpcServerEnabled(); + } + + public int getGrpcServerThreadCount() { + return getCommonConfig().getGrpcServerThreadCount(); + } + public long getTerminalStateTopicCheckerDelayMs() { return getCommonConfig().getTerminalStateTopicCheckerDelayMs(); } diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceHelixAdmin.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceHelixAdmin.java index 66329256d61..5f1bcd0419f 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceHelixAdmin.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/VeniceHelixAdmin.java @@ -3885,6 +3885,7 @@ private boolean truncateKafkaTopic( String kafkaTopicName, long deprecatedJobTopicRetentionMs) { try { + if (topicManager .updateTopicRetention(pubSubTopicRepository.getTopic(kafkaTopicName), deprecatedJobTopicRetentionMs)) { return true; @@ -6773,7 +6774,9 @@ public Instance getLeaderController(String clusterName) { id, Utils.parseHostFromHelixNodeIdentifier(id), Utils.parsePortFromHelixNodeIdentifier(id), - multiClusterConfigs.getAdminSecurePort()); + multiClusterConfigs.getAdminSecurePort(), + multiClusterConfigs.getAdminGrpcPort(), + multiClusterConfigs.getAdminSecureGrpcPort()); } } if (attempt < maxAttempts) { diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/AdminSparkServer.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/AdminSparkServer.java index c77d7fafab5..25b73ee0ac4 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/AdminSparkServer.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/AdminSparkServer.java @@ -170,6 +170,7 @@ public class AdminSparkServer extends AbstractVeniceService { private final boolean disableParentRequestTopicForStreamPushes; private final PubSubTopicRepository pubSubTopicRepository; + private final VeniceControllerRequestHandler requestHandler; public AdminSparkServer( int port, @@ -183,13 +184,15 @@ public AdminSparkServer( List disabledRoutes, VeniceProperties jettyConfigOverrides, boolean disableParentRequestTopicForStreamPushes, - PubSubTopicRepository pubSubTopicRepository) { + PubSubTopicRepository pubSubTopicRepository, + VeniceControllerRequestHandler requestHandler) { this.port = port; this.enforceSSL = enforceSSL; this.sslEnabled = sslConfig.isPresent(); this.sslConfig = sslConfig; this.checkReadMethodForKafka = checkReadMethodForKafka; this.accessController = accessController; + this.requestHandler = requestHandler; // Note: admin is passed in as a reference. The expectation is the source of the admin will // close it so we don't close it in stopInner() this.admin = admin; @@ -279,7 +282,8 @@ public boolean startInner() throws Exception { }); // Build all different routes - ControllerRoutes controllerRoutes = new ControllerRoutes(sslEnabled, accessController, pubSubTopicRepository); + ControllerRoutes controllerRoutes = + new ControllerRoutes(sslEnabled, accessController, pubSubTopicRepository, requestHandler); StoresRoutes storesRoutes = new StoresRoutes(sslEnabled, accessController, pubSubTopicRepository); JobRoutes jobRoutes = new JobRoutes(sslEnabled, accessController); SkipAdminRoute skipAdminRoute = new SkipAdminRoute(sslEnabled, accessController); @@ -362,7 +366,7 @@ public boolean startInner() throws Exception { new VeniceParentControllerRegionStateHandler(admin, createVersion.addVersionAndStartIngestion(admin))); httpService.post( NEW_STORE.getPath(), - new VeniceParentControllerRegionStateHandler(admin, createStoreRoute.createStore(admin))); + new VeniceParentControllerRegionStateHandler(admin, createStoreRoute.createStore(admin, requestHandler))); httpService.get( CHECK_RESOURCE_CLEANUP_FOR_STORE_CREATION.getPath(), new VeniceParentControllerRegionStateHandler( @@ -529,7 +533,7 @@ public boolean startInner() throws Exception { httpService.get( CLUSTER_DISCOVERY.getPath(), - new VeniceParentControllerRegionStateHandler(admin, ClusterDiscovery.discoverCluster(admin))); + new VeniceParentControllerRegionStateHandler(admin, ClusterDiscovery.discoverCluster(admin, requestHandler))); httpService.get( LIST_BOOTSTRAPPING_VERSIONS.getPath(), new VeniceParentControllerRegionStateHandler(admin, versionRoute.listBootstrappingVersions(admin))); diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ClusterDiscovery.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ClusterDiscovery.java index 7ddd12b55ae..6a9cf2e076f 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ClusterDiscovery.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ClusterDiscovery.java @@ -6,7 +6,7 @@ import com.linkedin.venice.HttpConstants; import com.linkedin.venice.controller.Admin; import com.linkedin.venice.controllerapi.D2ServiceDiscoveryResponse; -import com.linkedin.venice.utils.Pair; +import com.linkedin.venice.controllerapi.request.ClusterDiscoveryRequest; import spark.Route; @@ -14,16 +14,13 @@ public class ClusterDiscovery { /** * No ACL check; any user is allowed to discover cluster */ - public static Route discoverCluster(Admin admin) { + public static Route discoverCluster(Admin admin, VeniceControllerRequestHandler requestHandler) { return (request, response) -> { D2ServiceDiscoveryResponse responseObject = new D2ServiceDiscoveryResponse(); try { AdminSparkServer.validateParams(request, CLUSTER_DISCOVERY.getParams(), admin); - responseObject.setName(request.queryParams(NAME)); - Pair clusterToD2Pair = admin.discoverCluster(responseObject.getName()); - responseObject.setCluster(clusterToD2Pair.getFirst()); - responseObject.setD2Service(clusterToD2Pair.getSecond()); - responseObject.setServerD2Service(admin.getServerD2Service(clusterToD2Pair.getFirst())); + ClusterDiscoveryRequest requestObject = new ClusterDiscoveryRequest(request.queryParams(NAME)); + requestHandler.discoverCluster(requestObject, responseObject); } catch (Throwable e) { responseObject.setError(e); AdminSparkServer.handleError(e, request, response); diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ControllerRoutes.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ControllerRoutes.java index 5b1e1e7fad8..9a8e0214988 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ControllerRoutes.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/ControllerRoutes.java @@ -23,8 +23,8 @@ import com.linkedin.venice.controllerapi.LeaderControllerResponse; import com.linkedin.venice.controllerapi.PubSubTopicConfigResponse; import com.linkedin.venice.controllerapi.StoppableNodeStatusResponse; +import com.linkedin.venice.controllerapi.request.ControllerRequest; import com.linkedin.venice.exceptions.ErrorType; -import com.linkedin.venice.meta.Instance; import com.linkedin.venice.pubsub.PubSubTopicConfiguration; import com.linkedin.venice.pubsub.PubSubTopicRepository; import com.linkedin.venice.pubsub.api.PubSubTopic; @@ -45,13 +45,16 @@ public class ControllerRoutes extends AbstractRoute { private static final ObjectMapper OBJECT_MAPPER = ObjectMapperFactory.getInstance(); private final PubSubTopicRepository pubSubTopicRepository; + private final VeniceControllerRequestHandler requestHandler; public ControllerRoutes( boolean sslEnabled, Optional accessController, - PubSubTopicRepository pubSubTopicRepository) { + PubSubTopicRepository pubSubTopicRepository, + VeniceControllerRequestHandler requestHandler) { super(sslEnabled, accessController); this.pubSubTopicRepository = pubSubTopicRepository; + this.requestHandler = requestHandler; } /** @@ -64,13 +67,7 @@ public Route getLeaderController(Admin admin) { try { AdminSparkServer.validateParams(request, LEADER_CONTROLLER.getParams(), admin); String cluster = request.queryParams(CLUSTER); - responseObject.setCluster(cluster); - Instance leaderController = admin.getLeaderController(cluster); - responseObject.setUrl(leaderController.getUrl(isSslEnabled())); - if (leaderController.getPort() != leaderController.getSslPort()) { - // Controller is SSL Enabled - responseObject.setSecureUrl(leaderController.getUrl(true)); - } + requestHandler.getLeaderController(new ControllerRequest(cluster), responseObject); } catch (Throwable e) { responseObject.setError(e); AdminSparkServer.handleError(e, request, response); diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/CreateStore.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/CreateStore.java index c0ff0ff27b5..88d6f4a8a1f 100644 --- a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/CreateStore.java +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/CreateStore.java @@ -18,6 +18,7 @@ import com.linkedin.venice.controllerapi.AclResponse; import com.linkedin.venice.controllerapi.ControllerResponse; import com.linkedin.venice.controllerapi.NewStoreResponse; +import com.linkedin.venice.controllerapi.request.CreateNewStoreRequest; import java.util.Optional; import spark.Request; import spark.Route; @@ -31,7 +32,7 @@ public CreateStore(boolean sslEnabled, Optional accessC /** * @see Admin#createStore(String, String, String, String, String, boolean, Optional) */ - public Route createStore(Admin admin) { + public Route createStore(Admin admin, VeniceControllerRequestHandler requestHandler) { return new VeniceRouteHandler(NewStoreResponse.class) { @Override public void internalHandle(Request request, NewStoreResponse veniceResponse) { @@ -39,25 +40,18 @@ public void internalHandle(Request request, NewStoreResponse veniceResponse) { if (!checkIsAllowListUser(request, veniceResponse, () -> isAllowListUser(request))) { return; } + // Validate request parameters AdminSparkServer.validateParams(request, NEW_STORE.getParams(), admin); - String clusterName = request.queryParams(CLUSTER); - String storeName = request.queryParams(NAME); - String keySchema = request.queryParams(KEY_SCHEMA); - String valueSchema = request.queryParams(VALUE_SCHEMA); - boolean isSystemStore = Boolean.parseBoolean(request.queryParams(IS_SYSTEM_STORE)); - - String owner = AdminSparkServer.getOptionalParameterValue(request, OWNER); - if (owner == null) { - owner = ""; - } - - String accessPerm = request.queryParams(ACCESS_PERMISSION); - Optional accessPermissions = Optional.ofNullable(accessPerm); - - veniceResponse.setCluster(clusterName); - veniceResponse.setName(storeName); - veniceResponse.setOwner(owner); - admin.createStore(clusterName, storeName, owner, keySchema, valueSchema, isSystemStore, accessPermissions); + // Extract the parameters from the spark request and create the generic request object + CreateNewStoreRequest storeRequest = new CreateNewStoreRequest( + request.queryParams(CLUSTER), + request.queryParams(NAME), + AdminSparkServer.getOptionalParameterValue(request, OWNER), + request.queryParams(KEY_SCHEMA), + request.queryParams(VALUE_SCHEMA), + request.queryParams(ACCESS_PERMISSION), + Boolean.parseBoolean(request.queryParams(IS_SYSTEM_STORE))); + requestHandler.createStore(storeRequest, veniceResponse); } }; } diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/VeniceControllerGrpcServiceImpl.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/VeniceControllerGrpcServiceImpl.java new file mode 100644 index 00000000000..b3c5d8a3bb5 --- /dev/null +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/VeniceControllerGrpcServiceImpl.java @@ -0,0 +1,183 @@ +package com.linkedin.venice.controller.server; + +import static com.linkedin.venice.controllerapi.transport.GrpcRequestResponseConverter.getClusterStoreGrpcInfo; + +import com.linkedin.venice.controllerapi.D2ServiceDiscoveryResponse; +import com.linkedin.venice.controllerapi.LeaderControllerResponse; +import com.linkedin.venice.controllerapi.NewStoreResponse; +import com.linkedin.venice.controllerapi.request.ClusterDiscoveryRequest; +import com.linkedin.venice.controllerapi.request.ControllerRequest; +import com.linkedin.venice.controllerapi.request.CreateNewStoreRequest; +import com.linkedin.venice.controllerapi.transport.GrpcRequestResponseConverter; +import com.linkedin.venice.protocols.controller.ControllerGrpcErrorType; +import com.linkedin.venice.protocols.controller.CreateStoreGrpcRequest; +import com.linkedin.venice.protocols.controller.CreateStoreGrpcResponse; +import com.linkedin.venice.protocols.controller.DiscoverClusterGrpcRequest; +import com.linkedin.venice.protocols.controller.DiscoverClusterGrpcResponse; +import com.linkedin.venice.protocols.controller.LeaderControllerGrpcRequest; +import com.linkedin.venice.protocols.controller.LeaderControllerGrpcResponse; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc.VeniceControllerGrpcServiceImplBase; +import io.grpc.Status.Code; +import io.grpc.stub.StreamObserver; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + + +/** + * This class is a gRPC service implementation for the VeniceController public API. + */ +public class VeniceControllerGrpcServiceImpl extends VeniceControllerGrpcServiceImplBase { + private static final Logger LOGGER = LogManager.getLogger(VeniceControllerGrpcServiceImpl.class); + + private final VeniceControllerRequestHandler requestHandler; + + public VeniceControllerGrpcServiceImpl(VeniceControllerRequestHandler requestHandler) { + this.requestHandler = requestHandler; + } + + @Override + public void getLeaderController( + LeaderControllerGrpcRequest request, + StreamObserver responseObserver) { + String clusterName = request.getClusterName(); + LOGGER.info("Received gRPC request to get leader controller for cluster: {}", clusterName); + try { + LeaderControllerResponse response = new LeaderControllerResponse(); + ControllerRequest controllerRequest = new ControllerRequest(clusterName); + requestHandler.getLeaderController(controllerRequest, response); + LeaderControllerGrpcResponse.Builder grpcResponseBuilder = + LeaderControllerGrpcResponse.newBuilder().setClusterName(response.getCluster()).setHttpUrl(response.getUrl()); + + if (response.getSecureUrl() != null) { + grpcResponseBuilder.setHttpsUrl(response.getSecureUrl()); + } + if (response.getGrpcUrl() != null) { + grpcResponseBuilder.setGrpcUrl(response.getGrpcUrl()); + } + if (response.getSecureGrpcUrl() != null) { + grpcResponseBuilder.setSecureGrpcUrl(response.getSecureGrpcUrl()); + } + responseObserver.onNext(grpcResponseBuilder.build()); + responseObserver.onCompleted(); + } catch (IllegalArgumentException e) { + LOGGER.error("Invalid argument while getting leader controller for cluster: {}", clusterName, e); + GrpcRequestResponseConverter.sendErrorResponse( + Code.INVALID_ARGUMENT, + ControllerGrpcErrorType.BAD_REQUEST, + e, + clusterName, + null, + responseObserver); + } catch (Exception e) { + LOGGER.error("Error while getting leader controller for cluster: {}", clusterName, e); + GrpcRequestResponseConverter.sendErrorResponse( + Code.INTERNAL, + ControllerGrpcErrorType.GENERAL_ERROR, + e, + clusterName, + null, + responseObserver); + } + } + + @Override + public void discoverClusterForStore( + DiscoverClusterGrpcRequest grpcRequest, + StreamObserver responseObserver) { + String storeName = grpcRequest.getStoreName(); + LOGGER.info("Received gRPC request to discover cluster for store: {}", storeName); + try { + D2ServiceDiscoveryResponse response = new D2ServiceDiscoveryResponse(); + requestHandler.discoverCluster(new ClusterDiscoveryRequest(grpcRequest.getStoreName()), response); + DiscoverClusterGrpcResponse.Builder responseBuilder = + DiscoverClusterGrpcResponse.newBuilder().setStoreName(response.getName()); + if (response.getCluster() != null) { + responseBuilder.setClusterName(response.getCluster()); + } + if (response.getD2Service() != null) { + responseBuilder.setD2Service(response.getD2Service()); + } + if (response.getServerD2Service() != null) { + responseBuilder.setServerD2Service(response.getServerD2Service()); + } + if (response.getZkAddress() != null) { + responseBuilder.setZkAddress(response.getZkAddress()); + } + if (response.getKafkaBootstrapServers() != null) { + responseBuilder.setPubSubBootstrapServers(response.getKafkaBootstrapServers()); + } + responseObserver.onNext(responseBuilder.build()); + responseObserver.onCompleted(); + } catch (IllegalArgumentException e) { + LOGGER.error("Invalid argument while discovering cluster for store: {}", storeName, e); + GrpcRequestResponseConverter.sendErrorResponse( + Code.INVALID_ARGUMENT, + ControllerGrpcErrorType.BAD_REQUEST, + e, + null, + storeName, + responseObserver); + } catch (Exception e) { + LOGGER.error("Error while discovering cluster for store: {}", storeName, e); + GrpcRequestResponseConverter.sendErrorResponse( + Code.INTERNAL, + ControllerGrpcErrorType.GENERAL_ERROR, + e, + null, + storeName, + responseObserver); + } + } + + @Override + public void createStore( + CreateStoreGrpcRequest grpcRequest, + StreamObserver responseObserver) { + String clusterName = grpcRequest.getClusterStoreInfo().getClusterName(); + String storeName = grpcRequest.getClusterStoreInfo().getStoreName(); + LOGGER.info("Received gRPC request to create store: {} in cluster: {}", storeName, clusterName); + try { + // TODO (sushantmane) : Add the ACL check for allowlist users here + + // Convert the gRPC request to the internal request object + CreateNewStoreRequest request = new CreateNewStoreRequest( + grpcRequest.getClusterStoreInfo().getClusterName(), + grpcRequest.getClusterStoreInfo().getStoreName(), + grpcRequest.hasOwner() ? grpcRequest.getOwner() : null, + grpcRequest.getKeySchema(), + grpcRequest.getValueSchema(), + grpcRequest.hasAccessPermission() ? grpcRequest.getAccessPermission() : null, + grpcRequest.getIsSystemStore()); + + // Create the store using the internal request object + NewStoreResponse response = new NewStoreResponse(); + requestHandler.createStore(request, response); + + // Convert the internal response object to the gRPC response object and send the gRPC response + responseObserver.onNext( + CreateStoreGrpcResponse.newBuilder() + .setClusterStoreInfo(getClusterStoreGrpcInfo(response)) + .setOwner(response.getOwner()) + .build()); + responseObserver.onCompleted(); + } catch (IllegalArgumentException e) { + LOGGER.error("Invalid argument while creating store: {} in cluster: {}", storeName, clusterName, e); + GrpcRequestResponseConverter.sendErrorResponse( + Code.INVALID_ARGUMENT, + ControllerGrpcErrorType.BAD_REQUEST, + e, + clusterName, + storeName, + responseObserver); + } catch (Exception e) { + LOGGER.error("Error while creating store: {} in cluster: {}", storeName, clusterName, e); + GrpcRequestResponseConverter.sendErrorResponse( + Code.INTERNAL, + ControllerGrpcErrorType.GENERAL_ERROR, + e, + clusterName, + storeName, + responseObserver); + } + } +} diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/VeniceControllerRequestHandler.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/VeniceControllerRequestHandler.java new file mode 100644 index 00000000000..03290607392 --- /dev/null +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/VeniceControllerRequestHandler.java @@ -0,0 +1,97 @@ +package com.linkedin.venice.controller.server; + +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ControllerRequestHandlerDependencies; +import com.linkedin.venice.controllerapi.D2ServiceDiscoveryResponse; +import com.linkedin.venice.controllerapi.LeaderControllerResponse; +import com.linkedin.venice.controllerapi.NewStoreResponse; +import com.linkedin.venice.controllerapi.request.ClusterDiscoveryRequest; +import com.linkedin.venice.controllerapi.request.ControllerRequest; +import com.linkedin.venice.controllerapi.request.CreateNewStoreRequest; +import com.linkedin.venice.meta.Instance; +import com.linkedin.venice.utils.Pair; +import java.util.Optional; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + + +/** + * The core handler for processing incoming requests in the VeniceController. + * Acts as the central entry point for handling requests received via both HTTP/REST and gRPC protocols. + * This class is responsible for managing all request handling operations for the VeniceController. + */ +public class VeniceControllerRequestHandler { + private static final Logger LOGGER = LogManager.getLogger(VeniceControllerGrpcServiceImpl.class); + private final Admin admin; + private final boolean sslEnabled; + + public VeniceControllerRequestHandler(ControllerRequestHandlerDependencies dependencies) { + this.admin = dependencies.getAdmin(); + this.sslEnabled = dependencies.isSslEnabled(); + } + + // visibility: package-private + boolean isSslEnabled() { + return sslEnabled; + } + + /** + * The response is passed as an argument to avoid creating duplicate response objects for HTTP requests + * and to simplify unit testing with gRPC. Once the transition to gRPC is complete, we can eliminate + * the need to pass the response as an argument and instead construct and return it directly within the method. + */ + public void getLeaderController(ControllerRequest request, LeaderControllerResponse response) { + String clusterName = request.getClusterName(); + response.setCluster(clusterName); + + Instance leaderControllerInstance = admin.getLeaderController(clusterName); + response.setUrl(leaderControllerInstance.getUrl(isSslEnabled())); + if (leaderControllerInstance.getPort() != leaderControllerInstance.getSslPort()) { + // Controller is SSL Enabled + response.setSecureUrl(leaderControllerInstance.getUrl(true)); + } + response.setGrpcUrl(leaderControllerInstance.getGrpcUrl()); + response.setSecureGrpcUrl(leaderControllerInstance.getGrpcSslUrl()); + } + + public void discoverCluster(ClusterDiscoveryRequest request, D2ServiceDiscoveryResponse response) { + String storeName = request.getStoreName(); + LOGGER.info("Discovering cluster for store: {}", storeName); + Pair clusterToD2Pair = admin.discoverCluster(storeName); + response.setName(storeName); + response.setCluster(clusterToD2Pair.getFirst()); + response.setD2Service(clusterToD2Pair.getSecond()); + response.setServerD2Service(admin.getServerD2Service(clusterToD2Pair.getFirst())); + } + + /** + * Creates a new store in the specified Venice cluster with the provided parameters. + * @param request the request object containing all necessary details for the creation of the store + */ + public void createStore(CreateNewStoreRequest request, NewStoreResponse response) { + String clusterName = request.getClusterName(); + String storeName = request.getStoreName(); + String keySchema = request.getKeySchema(); + String valueSchema = request.getValueSchema(); + String owner = request.getOwner(); + Optional accessPermissions = Optional.ofNullable(request.getAccessPermissions()); + boolean isSystemStore = request.isSystemStore(); + + LOGGER.info( + "Creating store: {} in cluster: {} with owner: {} and key schema: {} and value schema: {} and isSystemStore: {} and access permissions: {}", + storeName, + clusterName, + owner, + keySchema, + valueSchema, + isSystemStore, + accessPermissions); + + admin.createStore(clusterName, storeName, owner, keySchema, valueSchema, isSystemStore, accessPermissions); + + response.setCluster(clusterName); + response.setName(storeName); + response.setOwner(owner); + LOGGER.info("Successfully created store: {} in cluster: {}", storeName, clusterName); + } +} diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/grpc/ControllerSslSessionInterceptor.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/grpc/ControllerSslSessionInterceptor.java new file mode 100644 index 00000000000..9bac6744535 --- /dev/null +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/grpc/ControllerSslSessionInterceptor.java @@ -0,0 +1,71 @@ +package com.linkedin.venice.controller.server.grpc; + +import com.linkedin.venice.grpc.GrpcUtils; +import com.linkedin.venice.protocols.controller.ControllerGrpcErrorType; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcErrorInfo; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import java.net.SocketAddress; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLSession; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + + +public class ControllerSslSessionInterceptor implements ServerInterceptor { + private static final Logger LOGGER = LogManager.getLogger(ControllerSslSessionInterceptor.class); + + public static final Context.Key CLIENT_CERTIFICATE_CONTEXT_KEY = + Context.key("controller-client-certificate"); + public static final Context.Key CLIENT_ADDRESS_CONTEXT_KEY = Context.key("controller-client-address"); + + private static final VeniceControllerGrpcErrorInfo NON_SSL_ERROR_INFO = VeniceControllerGrpcErrorInfo.newBuilder() + .setStatusCode(Status.UNAUTHENTICATED.getCode().value()) + .setErrorType(ControllerGrpcErrorType.CONNECTION_ERROR) + .setErrorMessage("SSL connection required") + .build(); + + private static final StatusRuntimeException NON_SSL_CONNECTION_STATUS = StatusProto.toStatusRuntimeException( + com.google.rpc.Status.newBuilder() + .setCode(Status.UNAUTHENTICATED.getCode().value()) + .addDetails(com.google.protobuf.Any.pack(NON_SSL_ERROR_INFO)) + .build()); + + @Override + public io.grpc.ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler serverCallHandler) { + SocketAddress remoteAddress = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + String remoteAddressStr = remoteAddress != null ? remoteAddress.toString() : "unknown"; + SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); + if (sslSession == null) { + LOGGER.debug("SSL not enabled"); + serverCall.close(NON_SSL_CONNECTION_STATUS.getStatus(), NON_SSL_CONNECTION_STATUS.getTrailers()); + return new ServerCall.Listener() { + }; + } + + X509Certificate clientCert; + try { + clientCert = GrpcUtils.extractGrpcClientCert(serverCall); + } catch (Exception e) { + LOGGER.error("Failed to extract client certificate", e); + serverCall.close(NON_SSL_CONNECTION_STATUS.getStatus(), NON_SSL_CONNECTION_STATUS.getTrailers()); + return new ServerCall.Listener() { + }; + } + Context context = Context.current() + .withValue(CLIENT_CERTIFICATE_CONTEXT_KEY, clientCert) + .withValue(CLIENT_ADDRESS_CONTEXT_KEY, remoteAddressStr); + return Contexts.interceptCall(context, serverCall, metadata, serverCallHandler); + } +} diff --git a/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/grpc/ParentControllerRegionValidationInterceptor.java b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/grpc/ParentControllerRegionValidationInterceptor.java new file mode 100644 index 00000000000..13222dda746 --- /dev/null +++ b/services/venice-controller/src/main/java/com/linkedin/venice/controller/server/grpc/ParentControllerRegionValidationInterceptor.java @@ -0,0 +1,67 @@ +package com.linkedin.venice.controller.server.grpc; + +import static com.linkedin.venice.controller.ParentControllerRegionState.ACTIVE; +import static com.linkedin.venice.controller.server.VeniceParentControllerRegionStateHandler.ACTIVE_CHECK_FAILURE_WARN_MESSAGE_PREFIX; + +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ParentControllerRegionState; +import com.linkedin.venice.protocols.controller.ControllerGrpcErrorType; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcErrorInfo; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + + +/** + * Interceptor to verify that the parent controller is active before processing requests within its region. + */ +public class ParentControllerRegionValidationInterceptor implements ServerInterceptor { + private static final Logger LOGGER = LogManager.getLogger(ParentControllerRegionValidationInterceptor.class); + private static final VeniceControllerGrpcErrorInfo.Builder ERROR_INFO_BUILDER = + VeniceControllerGrpcErrorInfo.newBuilder() + .setErrorType(ControllerGrpcErrorType.INCORRECT_CONTROLLER) + .setStatusCode(Status.FAILED_PRECONDITION.getCode().value()); + + private static final com.google.rpc.Status.Builder RPC_STATUS_BUILDER = com.google.rpc.Status.newBuilder() + .setCode(Status.FAILED_PRECONDITION.getCode().value()) + .setMessage("Parent controller is not active"); + + private final Admin admin; + + public ParentControllerRegionValidationInterceptor(Admin admin) { + this.admin = admin; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + ParentControllerRegionState parentControllerRegionState = admin.getParentControllerRegionState(); + boolean isParent = admin.isParent(); + if (isParent && parentControllerRegionState != ACTIVE) { + LOGGER.debug( + "Parent controller is not active. Rejecting the request: {} from source: {}", + call.getMethodDescriptor().getFullMethodName(), + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + // Retrieve the full method name + String fullMethodName = call.getMethodDescriptor().getFullMethodName(); + VeniceControllerGrpcErrorInfo errorInfo = + ERROR_INFO_BUILDER.setErrorMessage(ACTIVE_CHECK_FAILURE_WARN_MESSAGE_PREFIX + ": " + fullMethodName).build(); + // Note: On client side convert FAILED_PRECONDITION to SC_MISDIRECTED_REQUEST + com.google.rpc.Status rpcStatus = RPC_STATUS_BUILDER.addDetails(com.google.protobuf.Any.pack(errorInfo)).build(); + StatusRuntimeException exception = StatusProto.toStatusRuntimeException(rpcStatus); + call.close(exception.getStatus(), exception.getTrailers()); + return new ServerCall.Listener() { + }; + } + return next.startCall(call, headers); + } +} diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/ControllerRequestHandlerDependenciesTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/ControllerRequestHandlerDependenciesTest.java new file mode 100644 index 00000000000..103e5e9962f --- /dev/null +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/ControllerRequestHandlerDependenciesTest.java @@ -0,0 +1,98 @@ +package com.linkedin.venice.controller; + +import static org.mockito.Mockito.*; +import static org.testng.Assert.*; + +import com.linkedin.venice.SSLConfig; +import com.linkedin.venice.acl.DynamicAccessController; +import com.linkedin.venice.controllerapi.ControllerRoute; +import com.linkedin.venice.pubsub.PubSubTopicRepository; +import com.linkedin.venice.utils.VeniceProperties; +import io.tehuti.metrics.MetricsRepository; +import java.util.Collections; +import org.testng.annotations.Test; + + +public class ControllerRequestHandlerDependenciesTest { + @Test + public void testBuilderWithAllFieldsSet() { + Admin admin = mock(Admin.class); + SSLConfig sslConfig = mock(SSLConfig.class); + DynamicAccessController accessController = mock(DynamicAccessController.class); + PubSubTopicRepository pubSubTopicRepository = mock(PubSubTopicRepository.class); + MetricsRepository metricsRepository = mock(MetricsRepository.class); + VeniceProperties veniceProperties = mock(VeniceProperties.class); + ControllerRoute route = ControllerRoute.STORE; + + ControllerRequestHandlerDependencies dependencies = + new ControllerRequestHandlerDependencies.Builder().setAdmin(admin) + .setClusters(Collections.singleton("testCluster")) + .setEnforceSSL(true) + .setSslEnabled(true) + .setCheckReadMethodForKafka(true) + .setSslConfig(sslConfig) + .setAccessController(accessController) + .setDisabledRoutes(Collections.singletonList(route)) + .setDisableParentRequestTopicForStreamPushes(true) + .setPubSubTopicRepository(pubSubTopicRepository) + .setMetricsRepository(metricsRepository) + .setVeniceProperties(veniceProperties) + .build(); + + assertEquals(dependencies.getAdmin(), admin); + assertEquals(dependencies.getClusters(), Collections.singleton("testCluster")); + assertTrue(dependencies.isEnforceSSL()); + assertTrue(dependencies.isSslEnabled()); + assertTrue(dependencies.isCheckReadMethodForKafka()); + assertEquals(dependencies.getSslConfig(), sslConfig); + assertEquals(dependencies.getAccessController(), accessController); + assertEquals(dependencies.getDisabledRoutes(), Collections.singletonList(route)); + assertTrue(dependencies.isDisableParentRequestTopicForStreamPushes()); + assertEquals(dependencies.getPubSubTopicRepository(), pubSubTopicRepository); + assertEquals(dependencies.getMetricsRepository(), metricsRepository); + assertEquals(dependencies.getVeniceProperties(), veniceProperties); + } + + @Test + public void testBuilderWithDefaultPubSubTopicRepository() { + Admin admin = mock(Admin.class); + + ControllerRequestHandlerDependencies dependencies = + new ControllerRequestHandlerDependencies.Builder().setAdmin(admin) + .setClusters(Collections.singleton("testCluster")) + .build(); + + assertNotNull(dependencies.getPubSubTopicRepository()); + } + + @Test + public void testBuilderWithMissingAdmin() { + // Expect exception when admin is missing + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new ControllerRequestHandlerDependencies.Builder().setClusters(Collections.singleton("testCluster")) + .build()); + + assertEquals(exception.getMessage(), "admin is mandatory dependencies for VeniceControllerRequestHandler"); + } + + @Test + public void testDefaultValues() { + Admin admin = mock(Admin.class); + + ControllerRequestHandlerDependencies dependencies = + new ControllerRequestHandlerDependencies.Builder().setAdmin(admin) + .setClusters(Collections.singleton("testCluster")) + .build(); + + assertFalse(dependencies.isEnforceSSL()); + assertFalse(dependencies.isSslEnabled()); + assertFalse(dependencies.isCheckReadMethodForKafka()); + assertNull(dependencies.getSslConfig()); + assertNull(dependencies.getAccessController()); + assertTrue(dependencies.getDisabledRoutes().isEmpty()); + assertFalse(dependencies.isDisableParentRequestTopicForStreamPushes()); + assertNull(dependencies.getMetricsRepository()); + assertNull(dependencies.getVeniceProperties()); + } +} diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ClusterDiscoveryTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ClusterDiscoveryTest.java new file mode 100644 index 00000000000..16748e17ae1 --- /dev/null +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ClusterDiscoveryTest.java @@ -0,0 +1,69 @@ +package com.linkedin.venice.controller.server; + +import static com.linkedin.venice.controllerapi.ControllerRoute.CLUSTER_DISCOVERY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controllerapi.ControllerApiConstants; +import com.linkedin.venice.controllerapi.D2ServiceDiscoveryResponse; +import com.linkedin.venice.controllerapi.request.ClusterDiscoveryRequest; +import com.linkedin.venice.utils.ObjectMapperFactory; +import org.testng.annotations.Test; +import spark.QueryParamsMap; +import spark.Request; +import spark.Response; +import spark.Route; + + +public class ClusterDiscoveryTest { + @Test + public void testDiscoverCluster() throws Exception { + // Case 1: Store name is not provided + String clusterName = "test-cluster"; + String storeName = "test-store"; + String d2Service = "d2://test-service"; + String serverD2Service = "d2://test-server"; + + Admin admin = mock(Admin.class); + VeniceControllerRequestHandler requestHandler = mock(VeniceControllerRequestHandler.class); + + doAnswer(invocation -> { + D2ServiceDiscoveryResponse response = invocation.getArgument(1); + response.setName(storeName); + response.setCluster(clusterName); + response.setD2Service(d2Service); + response.setServerD2Service(serverD2Service); + return null; + }).when(requestHandler).discoverCluster(any(ClusterDiscoveryRequest.class), any(D2ServiceDiscoveryResponse.class)); + + Request request = mock(Request.class); + when(request.pathInfo()).thenReturn(CLUSTER_DISCOVERY.getPath()); + when(request.queryParams(eq(ControllerApiConstants.NAME))).thenReturn(storeName); + Response response = mock(Response.class); + + Route discoverCluster = ClusterDiscovery.discoverCluster(admin, requestHandler); + D2ServiceDiscoveryResponse d2ServiceDiscoveryResponse = ObjectMapperFactory.getInstance() + .readValue(discoverCluster.handle(request, response).toString(), D2ServiceDiscoveryResponse.class); + assertNotNull(d2ServiceDiscoveryResponse, "Response should not be null"); + assertEquals(d2ServiceDiscoveryResponse.getName(), storeName, "Store name should match"); + assertEquals(d2ServiceDiscoveryResponse.getCluster(), clusterName, "Cluster name should match"); + assertEquals(d2ServiceDiscoveryResponse.getD2Service(), d2Service, "D2 service should match"); + assertEquals(d2ServiceDiscoveryResponse.getServerD2Service(), serverD2Service, "Server D2 service should match"); + + // Case 2: Store name is not provided + QueryParamsMap queryParamsMap = mock(QueryParamsMap.class); + when(request.queryMap()).thenReturn(queryParamsMap); + when(request.queryParams(eq(ControllerApiConstants.NAME))).thenReturn(""); + D2ServiceDiscoveryResponse d2ServiceDiscoveryResponse2 = ObjectMapperFactory.getInstance() + .readValue(discoverCluster.handle(request, response).toString(), D2ServiceDiscoveryResponse.class); + assertNotNull(d2ServiceDiscoveryResponse2, "Response should not be null"); + assertTrue(d2ServiceDiscoveryResponse2.isError(), "Error should be present in the response"); + } +} diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ControllerRoutesTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ControllerRoutesTest.java index 5b4b31d76e4..053e32611a6 100644 --- a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ControllerRoutesTest.java +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/ControllerRoutesTest.java @@ -5,11 +5,13 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; import com.fasterxml.jackson.databind.ObjectMapper; import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ControllerRequestHandlerDependencies; import com.linkedin.venice.controller.InstanceRemovableStatuses; import com.linkedin.venice.controller.VeniceParentHelixAdmin; import com.linkedin.venice.controllerapi.AggregatedHealthStatusRequest; @@ -24,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import spark.Request; import spark.Response; @@ -37,39 +40,62 @@ public class ControllerRoutesTest { private static final String TEST_HOST = "localhost"; private static final int TEST_PORT = 2181; private static final int TEST_SSL_PORT = 2182; + private static final int TEST_GRPC_PORT = 2183; + private static final int TEST_GRPC_SSL_PORT = 2184; private final PubSubTopicRepository pubSubTopicRepository = new PubSubTopicRepository(); + private VeniceControllerRequestHandler requestHandler; + private ControllerRequestHandlerDependencies mockDependencies; + private Admin mockAdmin; + + @BeforeMethod(alwaysRun = true) + public void setUp() { + mockAdmin = mock(VeniceParentHelixAdmin.class); + mockDependencies = mock(ControllerRequestHandlerDependencies.class); + doReturn(mockAdmin).when(mockDependencies).getAdmin(); + requestHandler = new VeniceControllerRequestHandler(mockDependencies); + } + @Test public void testGetLeaderController() throws Exception { - Admin mockAdmin = mock(VeniceParentHelixAdmin.class); doReturn(true).when(mockAdmin).isLeaderControllerFor(anyString()); - Instance leaderController = new Instance(TEST_NODE_ID, TEST_HOST, TEST_PORT, TEST_SSL_PORT); + Instance leaderController = + new Instance(TEST_NODE_ID, TEST_HOST, TEST_PORT, TEST_SSL_PORT, TEST_GRPC_PORT, TEST_GRPC_SSL_PORT); + doReturn(leaderController).when(mockAdmin).getLeaderController(anyString()); Request request = mock(Request.class); doReturn(TEST_CLUSTER).when(request).queryParams(eq(ControllerApiConstants.CLUSTER)); - Route leaderControllerRoute = - new ControllerRoutes(false, Optional.empty(), pubSubTopicRepository).getLeaderController(mockAdmin); + Route leaderControllerRoute = new ControllerRoutes(false, Optional.empty(), pubSubTopicRepository, requestHandler) + .getLeaderController(mockAdmin); LeaderControllerResponse leaderControllerResponse = OBJECT_MAPPER.readValue( leaderControllerRoute.handle(request, mock(Response.class)).toString(), LeaderControllerResponse.class); assertEquals(leaderControllerResponse.getCluster(), TEST_CLUSTER); assertEquals(leaderControllerResponse.getUrl(), "http://" + TEST_HOST + ":" + TEST_PORT); assertEquals(leaderControllerResponse.getSecureUrl(), "https://" + TEST_HOST + ":" + TEST_SSL_PORT); + assertEquals(leaderControllerResponse.getGrpcUrl(), TEST_HOST + ":" + TEST_GRPC_PORT); + assertEquals(leaderControllerResponse.getSecureGrpcUrl(), TEST_HOST + ":" + TEST_GRPC_SSL_PORT); + + when(mockDependencies.isSslEnabled()).thenReturn(true); + requestHandler = new VeniceControllerRequestHandler(mockDependencies); - Route leaderControllerSslRoute = - new ControllerRoutes(true, Optional.empty(), pubSubTopicRepository).getLeaderController(mockAdmin); + Route leaderControllerSslRoute = new ControllerRoutes(true, Optional.empty(), pubSubTopicRepository, requestHandler) + .getLeaderController(mockAdmin); LeaderControllerResponse leaderControllerResponseSsl = OBJECT_MAPPER.readValue( leaderControllerSslRoute.handle(request, mock(Response.class)).toString(), LeaderControllerResponse.class); assertEquals(leaderControllerResponseSsl.getCluster(), TEST_CLUSTER); assertEquals(leaderControllerResponseSsl.getUrl(), "https://" + TEST_HOST + ":" + TEST_SSL_PORT); assertEquals(leaderControllerResponseSsl.getSecureUrl(), "https://" + TEST_HOST + ":" + TEST_SSL_PORT); + assertEquals(leaderControllerResponse.getGrpcUrl(), TEST_HOST + ":" + TEST_GRPC_PORT); + assertEquals(leaderControllerResponse.getSecureGrpcUrl(), TEST_HOST + ":" + TEST_GRPC_SSL_PORT); // Controller doesn't support SSL - Instance leaderNonSslController = new Instance(TEST_NODE_ID, TEST_HOST, TEST_PORT, TEST_PORT); + Instance leaderNonSslController = + new Instance(TEST_NODE_ID, TEST_HOST, TEST_PORT, TEST_PORT, TEST_GRPC_PORT, TEST_GRPC_SSL_PORT); doReturn(leaderNonSslController).when(mockAdmin).getLeaderController(anyString()); LeaderControllerResponse leaderControllerNonSslResponse = OBJECT_MAPPER.readValue( @@ -78,11 +104,14 @@ public void testGetLeaderController() throws Exception { assertEquals(leaderControllerNonSslResponse.getCluster(), TEST_CLUSTER); assertEquals(leaderControllerNonSslResponse.getUrl(), "http://" + TEST_HOST + ":" + TEST_PORT); assertEquals(leaderControllerNonSslResponse.getSecureUrl(), null); + assertEquals(leaderControllerNonSslResponse.getGrpcUrl(), TEST_HOST + ":" + TEST_GRPC_PORT); + assertEquals(leaderControllerNonSslResponse.getSecureGrpcUrl(), TEST_HOST + ":" + TEST_GRPC_SSL_PORT); } @Test public void testGetAggregatedHealthStatus() throws Exception { - ControllerRoutes controllerRoutes = new ControllerRoutes(false, Optional.empty(), pubSubTopicRepository); + ControllerRoutes controllerRoutes = + new ControllerRoutes(false, Optional.empty(), pubSubTopicRepository, requestHandler); Admin mockAdmin = mock(VeniceParentHelixAdmin.class); List instanceList = Arrays.asList("instance1_5000", "instance2_5000"); diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/CreateStoreTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/CreateStoreTest.java index ebd3cd34d92..eee5acdbd07 100644 --- a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/CreateStoreTest.java +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/CreateStoreTest.java @@ -15,10 +15,13 @@ import com.linkedin.venice.HttpConstants; import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ControllerRequestHandlerDependencies; +import com.linkedin.venice.controller.VeniceParentHelixAdmin; import com.linkedin.venice.utils.Utils; import java.util.HashMap; import java.util.Optional; import org.apache.commons.httpclient.HttpStatus; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import spark.QueryParamsMap; import spark.Request; @@ -29,17 +32,27 @@ public class CreateStoreTest { private static String clusterName = Utils.getUniqueString("test-cluster"); + private VeniceControllerRequestHandler requestHandler; + private Admin mockAdmin; + + @BeforeMethod + public void setUp() { + mockAdmin = mock(VeniceParentHelixAdmin.class); + ControllerRequestHandlerDependencies dependencies = mock(ControllerRequestHandlerDependencies.class); + doReturn(mockAdmin).when(dependencies).getAdmin(); + requestHandler = new VeniceControllerRequestHandler(dependencies); + } + @Test public void testCreateStoreWhenThrowsNPEInternally() throws Exception { - Admin admin = mock(Admin.class); Request request = mock(Request.class); Response response = mock(Response.class); String fakeMessage = "fake_message"; - doReturn(true).when(admin).isLeaderControllerFor(clusterName); + doReturn(true).when(mockAdmin).isLeaderControllerFor(clusterName); // Throws NPE here - doThrow(new NullPointerException(fakeMessage)).when(admin) + doThrow(new NullPointerException(fakeMessage)).when(mockAdmin) .createStore(any(), any(), any(), any(), any(), anyBoolean(), any()); QueryParamsMap paramsMap = mock(QueryParamsMap.class); @@ -54,22 +67,21 @@ public void testCreateStoreWhenThrowsNPEInternally() throws Exception { doReturn("\"string\"").when(request).queryParams(VALUE_SCHEMA); CreateStore createStoreRoute = new CreateStore(false, Optional.empty()); - Route createStoreRouter = createStoreRoute.createStore(admin); + Route createStoreRouter = createStoreRoute.createStore(mockAdmin, requestHandler); createStoreRouter.handle(request, response); verify(response).status(HttpStatus.SC_INTERNAL_SERVER_ERROR); } @Test(expectedExceptions = Error.class) public void testCreateStoreWhenThrowsError() throws Exception { - Admin admin = mock(Admin.class); Request request = mock(Request.class); Response response = mock(Response.class); String fakeMessage = "fake_message"; - doReturn(true).when(admin).isLeaderControllerFor(clusterName); + doReturn(true).when(mockAdmin).isLeaderControllerFor(clusterName); // Throws NPE here - doThrow(new Error(fakeMessage)).when(admin).createStore(any(), any(), any(), any(), any(), anyBoolean(), any()); + doThrow(new Error(fakeMessage)).when(mockAdmin).createStore(any(), any(), any(), any(), any(), anyBoolean(), any()); QueryParamsMap paramsMap = mock(QueryParamsMap.class); doReturn(new HashMap<>()).when(paramsMap).toMap(); @@ -83,17 +95,16 @@ public void testCreateStoreWhenThrowsError() throws Exception { doReturn("\"string\"").when(request).queryParams(VALUE_SCHEMA); CreateStore createStoreRoute = new CreateStore(false, Optional.empty()); - Route createStoreRouter = createStoreRoute.createStore(admin); + Route createStoreRouter = createStoreRoute.createStore(mockAdmin, requestHandler); createStoreRouter.handle(request, response); } @Test public void testCreateStoreWhenSomeParamNotPresent() throws Exception { - Admin admin = mock(Admin.class); Request request = mock(Request.class); Response response = mock(Response.class); - doReturn(true).when(admin).isLeaderControllerFor(clusterName); + doReturn(true).when(mockAdmin).isLeaderControllerFor(clusterName); QueryParamsMap paramsMap = mock(QueryParamsMap.class); doReturn(new HashMap<>()).when(paramsMap).toMap(); @@ -103,18 +114,17 @@ public void testCreateStoreWhenSomeParamNotPresent() throws Exception { doReturn(clusterName).when(request).queryParams(CLUSTER); CreateStore createStoreRoute = new CreateStore(false, Optional.empty()); - Route createStoreRouter = createStoreRoute.createStore(admin); + Route createStoreRouter = createStoreRoute.createStore(mockAdmin, requestHandler); createStoreRouter.handle(request, response); verify(response).status(HttpStatus.SC_BAD_REQUEST); } @Test public void testCreateStoreWhenNotLeaderController() throws Exception { - Admin admin = mock(Admin.class); Request request = mock(Request.class); Response response = mock(Response.class); - doReturn(false).when(admin).isLeaderControllerFor(clusterName); + doReturn(false).when(mockAdmin).isLeaderControllerFor(clusterName); QueryParamsMap paramsMap = mock(QueryParamsMap.class); doReturn(new HashMap<>()).when(paramsMap).toMap(); @@ -128,7 +138,7 @@ public void testCreateStoreWhenNotLeaderController() throws Exception { doReturn("\"string\"").when(request).queryParams(VALUE_SCHEMA); CreateStore createStoreRoute = new CreateStore(false, Optional.empty()); - Route createStoreRouter = createStoreRoute.createStore(admin); + Route createStoreRouter = createStoreRoute.createStore(mockAdmin, requestHandler); createStoreRouter.handle(request, response); verify(response).status(HttpConstants.SC_MISDIRECTED_REQUEST); } diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/TestVeniceRouteHandler.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/TestVeniceRouteHandler.java index f01b0650572..690d2d8cd97 100644 --- a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/TestVeniceRouteHandler.java +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/TestVeniceRouteHandler.java @@ -1,15 +1,20 @@ package com.linkedin.venice.controller.server; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ControllerRequestHandlerDependencies; +import com.linkedin.venice.controller.VeniceParentHelixAdmin; import com.linkedin.venice.controllerapi.ControllerResponse; import com.linkedin.venice.exceptions.ErrorType; import com.linkedin.venice.exceptions.ExceptionType; import com.linkedin.venice.utils.ObjectMapperFactory; import org.apache.commons.httpclient.HttpStatus; import org.testng.Assert; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import spark.Request; import spark.Response; @@ -17,6 +22,15 @@ public class TestVeniceRouteHandler { + private Admin mockAdmin; + + @BeforeMethod + public void setUp() { + mockAdmin = mock(VeniceParentHelixAdmin.class); + ControllerRequestHandlerDependencies dependencies = mock(ControllerRequestHandlerDependencies.class); + doReturn(mockAdmin).when(dependencies).getAdmin(); + } + @Test public void testIsAllowListUser() throws Exception { Route userAllowedRoute = new VeniceRouteHandler(ControllerResponse.class) { diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/VeniceControllerGrpcServiceImplTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/VeniceControllerGrpcServiceImplTest.java new file mode 100644 index 00000000000..3017822cec1 --- /dev/null +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/VeniceControllerGrpcServiceImplTest.java @@ -0,0 +1,256 @@ +package com.linkedin.venice.controller.server; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +import com.linkedin.venice.controllerapi.D2ServiceDiscoveryResponse; +import com.linkedin.venice.controllerapi.LeaderControllerResponse; +import com.linkedin.venice.controllerapi.NewStoreResponse; +import com.linkedin.venice.controllerapi.request.ClusterDiscoveryRequest; +import com.linkedin.venice.controllerapi.request.ControllerRequest; +import com.linkedin.venice.controllerapi.request.CreateNewStoreRequest; +import com.linkedin.venice.controllerapi.transport.GrpcRequestResponseConverter; +import com.linkedin.venice.exceptions.VeniceException; +import com.linkedin.venice.protocols.controller.ClusterStoreGrpcInfo; +import com.linkedin.venice.protocols.controller.ControllerGrpcErrorType; +import com.linkedin.venice.protocols.controller.CreateStoreGrpcRequest; +import com.linkedin.venice.protocols.controller.CreateStoreGrpcResponse; +import com.linkedin.venice.protocols.controller.DiscoverClusterGrpcRequest; +import com.linkedin.venice.protocols.controller.DiscoverClusterGrpcResponse; +import com.linkedin.venice.protocols.controller.LeaderControllerGrpcRequest; +import com.linkedin.venice.protocols.controller.LeaderControllerGrpcResponse; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcErrorInfo; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc.VeniceControllerGrpcServiceBlockingStub; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + + +public class VeniceControllerGrpcServiceImplTest { + private static final String TEST_CLUSTER = "test-cluster"; + private static final String TEST_STORE = "test-store"; + private static final String D2_TEST_SERVICE = "d2://test-service"; + private static final String D2_TEST_SERVER = "d2://test-server"; + private static final String HTTP_URL = "http://localhost:8080"; + private static final String HTTPS_URL = "https://localhost:8081"; + private static final String GRPC_URL = "grpc://localhost:8082"; + private static final String SECURE_GRPC_URL = "grpcs://localhost:8083"; + private static final String OWNER = "test-owner"; + private static final String KEY_SCHEMA = "int"; + private static final String VALUE_SCHEMA = "string"; + + private Server grpcServer; + private ManagedChannel grpcChannel; + private VeniceControllerRequestHandler requestHandler; + private VeniceControllerGrpcServiceBlockingStub blockingStub; + + @BeforeMethod + public void setUp() throws Exception { + requestHandler = mock(VeniceControllerRequestHandler.class); + + // Create a unique server name for the in-process server + String serverName = InProcessServerBuilder.generateName(); + + // Start the gRPC server in-process + grpcServer = InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(new VeniceControllerGrpcServiceImpl(requestHandler)) + .build() + .start(); + + // Create a channel to communicate with the server + grpcChannel = InProcessChannelBuilder.forName(serverName).directExecutor().build(); + + // Create a blocking stub to make calls to the server + blockingStub = VeniceControllerGrpcServiceGrpc.newBlockingStub(grpcChannel); + } + + @AfterMethod + public void tearDown() throws Exception { + if (grpcServer != null) { + grpcServer.shutdown(); + } + if (grpcChannel != null) { + grpcChannel.shutdown(); + } + } + + @Test + public void testGetLeaderController() { + // Case 1: Successful response + doAnswer(invocation -> { + LeaderControllerResponse controllerResponse = invocation.getArgument(1); + controllerResponse.setCluster(TEST_CLUSTER); + controllerResponse.setUrl(HTTP_URL); + controllerResponse.setSecureUrl(HTTPS_URL); + controllerResponse.setGrpcUrl(GRPC_URL); + controllerResponse.setSecureGrpcUrl(SECURE_GRPC_URL); + return null; + }).when(requestHandler).getLeaderController(any(ControllerRequest.class), any(LeaderControllerResponse.class)); + + LeaderControllerGrpcRequest request = LeaderControllerGrpcRequest.newBuilder().setClusterName(TEST_CLUSTER).build(); + LeaderControllerGrpcResponse actualResponse = blockingStub.getLeaderController(request); + + assertNotNull(actualResponse, "Response should not be null"); + assertEquals(actualResponse.getClusterName(), TEST_CLUSTER, "Cluster name should match"); + assertEquals(actualResponse.getHttpUrl(), HTTP_URL, "HTTP URL should match"); + assertEquals(actualResponse.getHttpsUrl(), HTTPS_URL, "HTTPS URL should match"); + assertEquals(actualResponse.getGrpcUrl(), GRPC_URL, "gRPC URL should match"); + assertEquals(actualResponse.getSecureGrpcUrl(), SECURE_GRPC_URL, "Secure gRPC URL should match"); + + // Case 2: Bad request as cluster name is missing + LeaderControllerGrpcRequest requestWithoutClusterName = LeaderControllerGrpcRequest.newBuilder().build(); + StatusRuntimeException e = + expectThrows(StatusRuntimeException.class, () -> blockingStub.getLeaderController(requestWithoutClusterName)); + assertNotNull(e.getStatus(), "Status should not be null"); + assertEquals(e.getStatus().getCode(), Status.INVALID_ARGUMENT.getCode()); + + VeniceControllerGrpcErrorInfo errorInfo = GrpcRequestResponseConverter.parseControllerGrpcError(e); + assertNotNull(errorInfo, "Error info should not be null"); + assertFalse(errorInfo.hasStoreName(), "Store name should not be present in the error info"); + assertEquals(errorInfo.getErrorType(), ControllerGrpcErrorType.BAD_REQUEST); + assertTrue(errorInfo.getErrorMessage().contains("The request is missing the cluster_name")); + + // Case 3: requestHandler throws an exception + doAnswer(invocation -> { + throw new VeniceException("Failed to get leader controller"); + }).when(requestHandler).getLeaderController(any(ControllerRequest.class), any(LeaderControllerResponse.class)); + StatusRuntimeException e2 = + expectThrows(StatusRuntimeException.class, () -> blockingStub.getLeaderController(request)); + assertNotNull(e2.getStatus(), "Status should not be null"); + assertEquals(e2.getStatus().getCode(), Status.INTERNAL.getCode()); + VeniceControllerGrpcErrorInfo errorInfo2 = GrpcRequestResponseConverter.parseControllerGrpcError(e2); + assertNotNull(errorInfo2, "Error info should not be null"); + assertTrue(errorInfo2.hasClusterName(), "Cluster name should be present in the error info"); + assertEquals(errorInfo2.getClusterName(), TEST_CLUSTER); + assertFalse(errorInfo2.hasStoreName(), "Store name should not be present in the error info"); + assertEquals(errorInfo2.getErrorType(), ControllerGrpcErrorType.GENERAL_ERROR); + assertTrue(errorInfo2.getErrorMessage().contains("Failed to get leader controller")); + } + + @Test + public void testDiscoverClusterForStore() { + // Case 1: Successful response + doAnswer(invocation -> { + D2ServiceDiscoveryResponse response = invocation.getArgument(1); + response.setName(TEST_STORE); + response.setCluster(TEST_CLUSTER); + response.setD2Service(D2_TEST_SERVICE); + response.setServerD2Service(D2_TEST_SERVER); + return null; + }).when(requestHandler).discoverCluster(any(ClusterDiscoveryRequest.class), any(D2ServiceDiscoveryResponse.class)); + DiscoverClusterGrpcRequest request = DiscoverClusterGrpcRequest.newBuilder().setStoreName(TEST_STORE).build(); + DiscoverClusterGrpcResponse actualResponse = blockingStub.discoverClusterForStore(request); + assertNotNull(actualResponse, "Response should not be null"); + assertEquals(actualResponse.getStoreName(), TEST_STORE, "Store name should match"); + assertEquals(actualResponse.getClusterName(), TEST_CLUSTER, "Cluster name should match"); + assertEquals(actualResponse.getD2Service(), D2_TEST_SERVICE, "D2 service should match"); + assertEquals(actualResponse.getServerD2Service(), D2_TEST_SERVER, "Server D2 service should match"); + + // Case 2: Bad request as store name is missing + DiscoverClusterGrpcRequest requestWithoutStoreName = DiscoverClusterGrpcRequest.newBuilder().build(); + StatusRuntimeException e = + expectThrows(StatusRuntimeException.class, () -> blockingStub.discoverClusterForStore(requestWithoutStoreName)); + assertNotNull(e.getStatus(), "Status should not be null"); + assertEquals(e.getStatus().getCode(), Status.INVALID_ARGUMENT.getCode()); + VeniceControllerGrpcErrorInfo errorInfo = GrpcRequestResponseConverter.parseControllerGrpcError(e); + assertNotNull(errorInfo, "Error info should not be null"); + assertFalse(errorInfo.hasClusterName(), "Cluster name should not be present in the error info"); + assertEquals(errorInfo.getErrorType(), ControllerGrpcErrorType.BAD_REQUEST); + assertTrue(errorInfo.getErrorMessage().contains("The request is missing the store_name")); + + // Case 3: requestHandler throws an exception + doAnswer(invocation -> { + throw new VeniceException("Failed to discover cluster"); + }).when(requestHandler).discoverCluster(any(ClusterDiscoveryRequest.class), any(D2ServiceDiscoveryResponse.class)); + StatusRuntimeException e2 = + expectThrows(StatusRuntimeException.class, () -> blockingStub.discoverClusterForStore(request)); + assertNotNull(e2.getStatus(), "Status should not be null"); + assertEquals(e2.getStatus().getCode(), Status.INTERNAL.getCode()); + VeniceControllerGrpcErrorInfo errorInfo2 = GrpcRequestResponseConverter.parseControllerGrpcError(e2); + assertNotNull(errorInfo2, "Error info should not be null"); + assertFalse(errorInfo2.hasClusterName(), "Cluster name should not be present in the error info"); + assertEquals(errorInfo2.getErrorType(), ControllerGrpcErrorType.GENERAL_ERROR); + assertTrue(errorInfo2.getErrorMessage().contains("Failed to discover cluster")); + } + + @Test + public void testCreateStore() { + // Case 1: Successful response + doAnswer(invocation -> { + NewStoreResponse newStoreResponse = invocation.getArgument(1); + newStoreResponse.setCluster(TEST_CLUSTER); + newStoreResponse.setName(TEST_STORE); + newStoreResponse.setOwner(OWNER); + return null; + }).when(requestHandler).createStore(any(CreateNewStoreRequest.class), any(NewStoreResponse.class)); + CreateStoreGrpcRequest request = CreateStoreGrpcRequest.newBuilder() + .setClusterStoreInfo( + ClusterStoreGrpcInfo.newBuilder().setClusterName(TEST_CLUSTER).setStoreName(TEST_STORE).build()) + .setOwner(OWNER) + .setKeySchema(KEY_SCHEMA) + .setValueSchema(VALUE_SCHEMA) + .build(); + CreateStoreGrpcResponse actualResponse = blockingStub.createStore(request); + assertNotNull(actualResponse, "Response should not be null"); + assertNotNull(actualResponse.getClusterStoreInfo(), "ClusterStoreInfo should not be null"); + assertEquals(actualResponse.getClusterStoreInfo().getClusterName(), TEST_CLUSTER, "Cluster name should match"); + assertEquals(actualResponse.getClusterStoreInfo().getStoreName(), TEST_STORE, "Store name should match"); + + // Case 2: Bad request as cluster name is missing + CreateStoreGrpcRequest requestWithoutClusterName = CreateStoreGrpcRequest.newBuilder() + .setOwner(OWNER) + .setKeySchema(KEY_SCHEMA) + .setValueSchema(VALUE_SCHEMA) + .build(); + StatusRuntimeException e = + expectThrows(StatusRuntimeException.class, () -> blockingStub.createStore(requestWithoutClusterName)); + assertNotNull(e.getStatus(), "Status should not be null"); + assertEquals(e.getStatus().getCode(), Status.INVALID_ARGUMENT.getCode()); + VeniceControllerGrpcErrorInfo errorInfo = GrpcRequestResponseConverter.parseControllerGrpcError(e); + assertEquals(errorInfo.getErrorType(), ControllerGrpcErrorType.BAD_REQUEST); + assertNotNull(errorInfo, "Error info should not be null"); + assertTrue(errorInfo.getErrorMessage().contains("The request is missing the cluster_name")); + + // Case 3: Bad request as store name is missing + CreateStoreGrpcRequest requestWithoutStoreName = CreateStoreGrpcRequest.newBuilder() + .setClusterStoreInfo(ClusterStoreGrpcInfo.newBuilder().setClusterName(TEST_CLUSTER).build()) + .setOwner(OWNER) + .setKeySchema(KEY_SCHEMA) + .setValueSchema(VALUE_SCHEMA) + .build(); + StatusRuntimeException e2 = + expectThrows(StatusRuntimeException.class, () -> blockingStub.createStore(requestWithoutStoreName)); + assertNotNull(e2.getStatus(), "Status should not be null"); + assertEquals(e2.getStatus().getCode(), Status.INVALID_ARGUMENT.getCode()); + VeniceControllerGrpcErrorInfo errorInfo2 = GrpcRequestResponseConverter.parseControllerGrpcError(e2); + assertEquals(errorInfo2.getErrorType(), ControllerGrpcErrorType.BAD_REQUEST); + assertNotNull(errorInfo2, "Error info should not be null"); + assertTrue(errorInfo2.getErrorMessage().contains("The request is missing the store_name")); + + // Case 4: requestHandler throws an exception + doAnswer(invocation -> { + throw new VeniceException("Failed to create store"); + }).when(requestHandler).createStore(any(CreateNewStoreRequest.class), any(NewStoreResponse.class)); + StatusRuntimeException e3 = expectThrows(StatusRuntimeException.class, () -> blockingStub.createStore(request)); + assertNotNull(e3.getStatus(), "Status should not be null"); + assertEquals(e3.getStatus().getCode(), Status.INTERNAL.getCode()); + VeniceControllerGrpcErrorInfo errorInfo3 = GrpcRequestResponseConverter.parseControllerGrpcError(e3); + assertNotNull(errorInfo3, "Error info should not be null"); + assertEquals(errorInfo3.getErrorType(), ControllerGrpcErrorType.GENERAL_ERROR); + assertTrue(errorInfo3.getErrorMessage().contains("Failed to create store")); + } +} diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/VeniceControllerRequestHandlerTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/VeniceControllerRequestHandlerTest.java new file mode 100644 index 00000000000..7e0fdde7904 --- /dev/null +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/VeniceControllerRequestHandlerTest.java @@ -0,0 +1,154 @@ +package com.linkedin.venice.controller.server; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ControllerRequestHandlerDependencies; +import com.linkedin.venice.controllerapi.D2ServiceDiscoveryResponse; +import com.linkedin.venice.controllerapi.LeaderControllerResponse; +import com.linkedin.venice.controllerapi.NewStoreResponse; +import com.linkedin.venice.controllerapi.request.ClusterDiscoveryRequest; +import com.linkedin.venice.controllerapi.request.ControllerRequest; +import com.linkedin.venice.controllerapi.request.CreateNewStoreRequest; +import com.linkedin.venice.meta.Instance; +import com.linkedin.venice.utils.Pair; +import java.util.Optional; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + + +public class VeniceControllerRequestHandlerTest { + private VeniceControllerRequestHandler requestHandler; + private Admin admin; + private ControllerRequestHandlerDependencies dependencies; + + @BeforeMethod + public void setUp() { + admin = mock(Admin.class); + dependencies = mock(ControllerRequestHandlerDependencies.class); + when(dependencies.getAdmin()).thenReturn(admin); + when(dependencies.isSslEnabled()).thenReturn(true); + requestHandler = new VeniceControllerRequestHandler(dependencies); + } + + @Test + public void testGetLeaderController() { + ControllerRequest request = mock(ControllerRequest.class); + LeaderControllerResponse response = new LeaderControllerResponse(); + Instance leaderInstance = mock(Instance.class); + + when(request.getClusterName()).thenReturn("testCluster"); + when(admin.getLeaderController("testCluster")).thenReturn(leaderInstance); + when(leaderInstance.getUrl(true)).thenReturn("https://leader-url:443"); + when(leaderInstance.getUrl(false)).thenReturn("http://leader-url:80"); + when(leaderInstance.getGrpcUrl()).thenReturn("leader-grpc-url:50051"); + when(leaderInstance.getGrpcSslUrl()).thenReturn("leader-grpc-url:50052"); + when(leaderInstance.getPort()).thenReturn(80); + when(leaderInstance.getSslPort()).thenReturn(443); // SSL enabled + when(leaderInstance.getGrpcPort()).thenReturn(50051); + when(leaderInstance.getGrpcSslPort()).thenReturn(50052); + + requestHandler.getLeaderController(request, response); + + assertEquals(response.getCluster(), "testCluster"); + assertEquals(response.getUrl(), "https://leader-url:443"); // SSL enabled + assertEquals(response.getSecureUrl(), "https://leader-url:443"); + assertEquals(response.getGrpcUrl(), "leader-grpc-url:50051"); + assertEquals(response.getSecureGrpcUrl(), "leader-grpc-url:50052"); + + // SSL not enabled + when(dependencies.isSslEnabled()).thenReturn(false); + requestHandler = new VeniceControllerRequestHandler(dependencies); + LeaderControllerResponse response1 = new LeaderControllerResponse(); + requestHandler.getLeaderController(request, response1); + assertEquals(response1.getUrl(), "http://leader-url:80"); + assertEquals(response1.getSecureUrl(), "https://leader-url:443"); + assertEquals(response1.getGrpcUrl(), "leader-grpc-url:50051"); + assertEquals(response1.getSecureGrpcUrl(), "leader-grpc-url:50052"); + } + + @Test + public void testDiscoverCluster() { + ClusterDiscoveryRequest request = mock(ClusterDiscoveryRequest.class); + D2ServiceDiscoveryResponse response = new D2ServiceDiscoveryResponse(); + Pair clusterToD2Pair = Pair.create("testCluster", "testD2Service"); + + when(request.getStoreName()).thenReturn("testStore"); + when(admin.discoverCluster("testStore")).thenReturn(clusterToD2Pair); + when(admin.getServerD2Service("testCluster")).thenReturn("testServerD2Service"); + + requestHandler.discoverCluster(request, response); + + assertEquals(response.getName(), "testStore"); + assertEquals(response.getCluster(), "testCluster"); + assertEquals(response.getD2Service(), "testD2Service"); + assertEquals(response.getServerD2Service(), "testServerD2Service"); + } + + @Test + public void testCreateStore() { + CreateNewStoreRequest request = mock(CreateNewStoreRequest.class); + NewStoreResponse response = new NewStoreResponse(); + + when(request.getClusterName()).thenReturn("testCluster"); + when(request.getStoreName()).thenReturn("testStore"); + when(request.getKeySchema()).thenReturn("testKeySchema"); + when(request.getValueSchema()).thenReturn("testValueSchema"); + when(request.getOwner()).thenReturn("testOwner"); + when(request.getAccessPermissions()).thenReturn("testAccessPermissions"); + when(request.isSystemStore()).thenReturn(false); + + requestHandler.createStore(request, response); + + verify(admin, times(1)).createStore( + "testCluster", + "testStore", + "testOwner", + "testKeySchema", + "testValueSchema", + false, + Optional.of("testAccessPermissions")); + assertEquals(response.getCluster(), "testCluster"); + assertEquals(response.getName(), "testStore"); + assertEquals(response.getOwner(), "testOwner"); + } + + @Test + public void testCreateStoreWithNullAccessPermissions() { + CreateNewStoreRequest request = mock(CreateNewStoreRequest.class); + NewStoreResponse response = new NewStoreResponse(); + + when(request.getClusterName()).thenReturn("testCluster"); + when(request.getStoreName()).thenReturn("testStore"); + when(request.getKeySchema()).thenReturn("testKeySchema"); + when(request.getValueSchema()).thenReturn("testValueSchema"); + when(request.getOwner()).thenReturn("testOwner"); + when(request.getAccessPermissions()).thenReturn(null); + when(request.isSystemStore()).thenReturn(true); + + requestHandler.createStore(request, response); + + verify(admin, times(1)).createStore( + "testCluster", + "testStore", + "testOwner", + "testKeySchema", + "testValueSchema", + true, + Optional.empty()); + assertEquals(response.getCluster(), "testCluster"); + assertEquals(response.getName(), "testStore"); + assertEquals(response.getOwner(), "testOwner"); + } + + @Test + public void testIsSslEnabled() { + boolean sslEnabled = requestHandler.isSslEnabled(); + assertTrue(sslEnabled); + } +} diff --git a/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/grpc/ParentControllerRegionValidationInterceptorTest.java b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/grpc/ParentControllerRegionValidationInterceptorTest.java new file mode 100644 index 00000000000..903be8bf835 --- /dev/null +++ b/services/venice-controller/src/test/java/com/linkedin/venice/controller/server/grpc/ParentControllerRegionValidationInterceptorTest.java @@ -0,0 +1,97 @@ +package com.linkedin.venice.controller.server.grpc; + +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import com.linkedin.venice.controller.Admin; +import com.linkedin.venice.controller.ParentControllerRegionState; +import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc; +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import org.mockito.ArgumentCaptor; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + + +public class ParentControllerRegionValidationInterceptorTest { + private ParentControllerRegionValidationInterceptor interceptor; + private Admin admin; + private ServerCall call; + private Metadata headers; + private ServerCallHandler next; + + @BeforeMethod + public void setUp() { + admin = mock(Admin.class); + call = mock(ServerCall.class, RETURNS_DEEP_STUBS); + headers = new Metadata(); + next = mock(ServerCallHandler.class); + interceptor = new ParentControllerRegionValidationInterceptor(admin); + } + + @Test + public void testActiveParentControllerPasses() { + when(admin.isParent()).thenReturn(true); + when(admin.getParentControllerRegionState()).thenReturn(ParentControllerRegionState.ACTIVE); + + interceptor.interceptCall(call, headers, next); + + verify(next, times(1)).startCall(call, headers); + verify(call, never()).close(any(), any()); + } + + @Test + public void testInactiveParentControllerRejectsRequest() { + when(admin.isParent()).thenReturn(true); + when(admin.getParentControllerRegionState()).thenReturn(ParentControllerRegionState.PASSIVE); + + MethodDescriptor methodDescriptor = VeniceControllerGrpcServiceGrpc.getGetLeaderControllerMethod(); + when(call.getMethodDescriptor()).thenReturn(methodDescriptor); + when(call.getAttributes()).thenReturn(Attributes.EMPTY); + + interceptor.interceptCall(call, headers, next); + + verify(call, times(1)).close(any(io.grpc.Status.class), any(Metadata.class)); + verify(next, never()).startCall(call, headers); + } + + @Test + public void testNonParentControllerPasses() { + when(admin.isParent()).thenReturn(false); + + interceptor.interceptCall(call, headers, next); + + verify(next, times(1)).startCall(call, headers); + verify(call, never()).close(any(), any()); + } + + @Test + public void testErrorMessageAndCodeOnRejection() { + when(admin.isParent()).thenReturn(true); + when(admin.getParentControllerRegionState()).thenReturn(ParentControllerRegionState.PASSIVE); + MethodDescriptor methodDescriptor = VeniceControllerGrpcServiceGrpc.getGetLeaderControllerMethod(); + when(call.getMethodDescriptor()).thenReturn(methodDescriptor); + when(call.getAttributes()).thenReturn(Attributes.EMPTY); + + interceptor.interceptCall(call, headers, next); + + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(io.grpc.Status.class); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(call, times(1)).close(statusCaptor.capture(), metadataCaptor.capture()); + verify(next, never()).startCall(call, headers); + + io.grpc.Status status = statusCaptor.getValue(); + assertTrue(status.getDescription().contains("Parent controller is not active")); + assertEquals(status.getCode(), io.grpc.Status.FAILED_PRECONDITION.getCode()); + } +} diff --git a/services/venice-server/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerConfigTest.java b/services/venice-server/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerConfigTest.java deleted file mode 100644 index e99d681cd84..00000000000 --- a/services/venice-server/src/test/java/com/linkedin/venice/grpc/VeniceGrpcServerConfigTest.java +++ /dev/null @@ -1,86 +0,0 @@ -package com.linkedin.venice.grpc; - -import static org.mockito.Mockito.mock; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertNotNull; -import static org.testng.Assert.assertNull; - -import com.linkedin.alpini.base.concurrency.ExecutorService; -import com.linkedin.venice.security.SSLFactory; -import io.grpc.BindableService; -import io.grpc.ServerCredentials; -import io.grpc.ServerInterceptor; -import java.util.concurrent.Executor; -import org.testng.annotations.Test; - - -public class VeniceGrpcServerConfigTest { - @Test - public void testDefaults() { - VeniceGrpcServerConfig config = new VeniceGrpcServerConfig.Builder().setPort(8080) - .setService(mock(BindableService.class)) - .setNumThreads(10) - .build(); - - assertEquals(config.getPort(), 8080); - assertNull(config.getCredentials()); - assertEquals(config.getInterceptors().size(), 0); - } - - @Test - public void testCustomCredentials() { - VeniceGrpcServerConfig config = new VeniceGrpcServerConfig.Builder().setPort(8080) - .setService(mock(BindableService.class)) - .setCredentials(mock(ServerCredentials.class)) - .setExecutor(mock(ExecutorService.class)) - .setNumThreads(10) - .build(); - - assertNotNull(config.getCredentials()); - assertEquals(config.getCredentials(), config.getCredentials()); - } - - @Test - public void testInterceptor() { - ServerInterceptor interceptor = mock(ServerInterceptor.class); - VeniceGrpcServerConfig config = new VeniceGrpcServerConfig.Builder().setPort(8080) - .setService(mock(BindableService.class)) - .setInterceptor(interceptor) - .setNumThreads(10) - .build(); - - assertEquals(config.getInterceptors().size(), 1); - assertEquals(config.getInterceptors().get(0), interceptor); - } - - @Test - public void testSSLFactory() { - SSLFactory sslFactory = mock(SSLFactory.class); - VeniceGrpcServerConfig config = new VeniceGrpcServerConfig.Builder().setPort(8080) - .setService(mock(BindableService.class)) - .setSslFactory(sslFactory) - .setNumThreads(10) - .build(); - - assertEquals(config.getSslFactory(), sslFactory); - } - - @Test - public void testNumThreadsAndExecutor() { - VeniceGrpcServerConfig.Builder configBuilder = - new VeniceGrpcServerConfig.Builder().setPort(1010).setService(mock(BindableService.class)).setNumThreads(10); - - VeniceGrpcServerConfig testExectorCreation = configBuilder.build(); - - Executor exec = testExectorCreation.getExecutor(); - assertNotNull(exec); - - VeniceGrpcServerConfig testCustomExecutor = configBuilder.setExecutor(exec).build(); - assertEquals(testCustomExecutor.getExecutor(), exec); - } - - @Test(expectedExceptions = IllegalArgumentException.class) - public void testNoService() { - new VeniceGrpcServerConfig.Builder().setPort(8080).build(); - } -}