diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 41d57b3205..d16e04c90c 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -67,7 +67,7 @@ jobs: ZWE_CONFIGS_APIML_SERVICE_ADDITIONALREGISTRATION_0_DISCOVERYSERVICEURLS: https://discovery-service-2:10011/eureka SERVER_MAX_HTTP_REQUEST_HEADER_SIZE: 16348 SERVER_WEBSOCKET_REQUESTBUFFERSIZE: 16348 - APIML_GATEWAY_ROUTING_SERVICESTOLIMITREQUESTRATE: discoverableclient + APIML_GATEWAY_SERVICESTOLIMITREQUESTRATE: discoverableclient APIML_GATEWAY_ROUTING_COOKIENAMEFORRATELIMIT: apimlAuthenticationToken zaas-service: image: ghcr.io/balhar-jakub/zaas-service:${{ github.run_id }}-${{ github.run_number }} diff --git a/common-service-core/src/main/java/org/zowe/apiml/constants/EurekaMetadataDefinition.java b/common-service-core/src/main/java/org/zowe/apiml/constants/EurekaMetadataDefinition.java index 66dfb69965..6fb841c9fc 100644 --- a/common-service-core/src/main/java/org/zowe/apiml/constants/EurekaMetadataDefinition.java +++ b/common-service-core/src/main/java/org/zowe/apiml/constants/EurekaMetadataDefinition.java @@ -41,6 +41,7 @@ private EurekaMetadataDefinition() { public static final String SERVICE_EXTERNAL_URL = "apiml.service.externalUrl"; public static final String SERVICE_SUPPORTING_CLIENT_CERT_FORWARDING = "apiml.service.supportClientCertForwarding"; public static final String ENABLE_URL_ENCODED_CHARACTERS = "apiml.enableUrlEncodedCharacters"; + public static final String APPLY_RATE_LIMITER_FILTER = "apiml.gateway.applyRateLimiterFilter"; public static final String APIML_ID = "apiml.service.apimlId"; public static final String REGISTRATION_TYPE = "apiml.registrationType"; diff --git a/gateway-package/src/main/resources/bin/start.sh b/gateway-package/src/main/resources/bin/start.sh index 8f70f2e601..65b425cb09 100755 --- a/gateway-package/src/main/resources/bin/start.sh +++ b/gateway-package/src/main/resources/bin/start.sh @@ -318,10 +318,10 @@ _BPX_JOBNAME=${ZWE_zowe_job_prefix}${GATEWAY_CODE} ${JAVA_BIN_DIR}java \ -Dapiml.gateway.cachePeriodSec=${ZWE_configs_apiml_gateway_registry_cachePeriodSec:-120} \ -Dapiml.gateway.registry.enabled=${ZWE_configs_apiml_gateway_registry_enabled:-false} \ -Dapiml.gateway.maxSimultaneousRequests=${ZWE_configs_gateway_registry_maxSimultaneousRequests:-20} \ - -Dapiml.gateway.rateLimiterCapacity=${ZWE_configs_apiml_gateway_routing_rateLimiterCapacity:-20} \ - -Dapiml.gateway.rateLimiterTokens=${ZWE_configs_apiml_gateway_routing_rateLimiterTokens:-20} \ - -Dapiml.gateway.rateLimiterRefillDuration=${ZWE_configs_apiml_gateway_routing_rateLimiterRefillDuration:-1} \ - -Dapiml.gateway.servicesToLimitRequestRate=${ZWE_configs_apiml_gateway_routing_servicesToLimitRequestRate:-} \ + -Dapiml.gateway.rateLimiterCapacity=${ZWE_configs_apiml_gateway_rateLimiterCapacity:-20} \ + -Dapiml.gateway.rateLimiterTokens=${ZWE_configs_apiml_gateway_rateLimiterTokens:-20} \ + -Dapiml.gateway.rateLimiterRefillDuration=${ZWE_configs_apiml_gateway_rateLimiterRefillDuration:-1} \ + -Dapiml.gateway.servicesToLimitRequestRate=${ZWE_configs_apiml_gateway_servicesToLimitRequestRate:-} \ -Dapiml.gateway.cookieNameForRateLimit=${cookieName:-apimlAuthenticationToken} \ -Dapiml.gateway.registry.metadata-key-allow-list=${ZWE_configs_gateway_registry_metadataKeyAllowList:-} \ -Dapiml.gateway.refresh-interval-ms=${ZWE_configs_gateway_registry_refreshIntervalMs:-30000} \ diff --git a/gateway-service/src/main/java/org/zowe/apiml/gateway/config/RoutingConfig.java b/gateway-service/src/main/java/org/zowe/apiml/gateway/config/RoutingConfig.java index ffd00835a6..ca19ce5598 100644 --- a/gateway-service/src/main/java/org/zowe/apiml/gateway/config/RoutingConfig.java +++ b/gateway-service/src/main/java/org/zowe/apiml/gateway/config/RoutingConfig.java @@ -62,10 +62,6 @@ public List filters() { retryFilter.addArg("series", ""); filters.add(retryFilter); - FilterDefinition rateLimiterFilter = new FilterDefinition(); - rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory"); - filters.add(rateLimiterFilter); - for (String headerName : ignoredHeadersWhenCorsEnabled.split(",")) { FilterDefinition removeHeaders = new FilterDefinition(); removeHeaders.setName("RemoveRequestHeader"); diff --git a/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiter.java b/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiter.java index 90ec956490..c24c7d292b 100644 --- a/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiter.java +++ b/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiter.java @@ -27,12 +27,15 @@ public class InMemoryRateLimiter implements RateLimiter { private final Map cache = new ConcurrentHashMap<>(); - @Value("${apiml.gateway.routing.rateLimiterCapacity:20}") + + @Value("${apiml.gateway.rateLimiterCapacity:20}") int capacity; - @Value("${apiml.gateway.routing.rateLimiterTokens:20}") + + @Value("${apiml.gateway.rateLimiterTokens:20}") int tokens; - @Value("${apiml.gateway.routing.rateLimiterRefillDuration:1}") - Integer refillDuration; + + @Value("${apiml.gateway.rateLimiterRefillDuration:1}") + int refillDuration; @Override public Mono isAllowed(String routeId, String id) { @@ -55,6 +58,12 @@ private Map getHeaders(Bucket bucket) { return headers; } + public void setParameters(int capacity, int tokens, int refillDuration) { + this.capacity = (capacity != 0) ? capacity : this.capacity; + this.tokens = (tokens != 0) ? tokens : this.tokens; + this.refillDuration = (refillDuration != 0) ? refillDuration : this.refillDuration;; + } + @Override public Map getConfig() { Config defaultConfig = new Config(); diff --git a/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactory.java b/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactory.java index c6bddfe298..7ec77b30a6 100644 --- a/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactory.java +++ b/gateway-service/src/main/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactory.java @@ -14,7 +14,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import lombok.Getter; import lombok.Setter; -import org.springframework.beans.factory.annotation.Value; import org.springframework.cloud.gateway.filter.GatewayFilter; import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory; import org.springframework.http.HttpStatus; @@ -34,13 +33,10 @@ public class InMemoryRateLimiterFilterFactory extends AbstractGatewayFilterFacto @InjectApimlLogger private ApimlLogger apimlLog = ApimlLogger.empty(); - private final InMemoryRateLimiter rateLimiter; + private InMemoryRateLimiter rateLimiter; private final KeyResolver keyResolver; - @Value("${apiml.gateway.routing.servicesToLimitRequestRate:-}") - List serviceIds; - private final ObjectMapper mapper; private final MessageService messageService; @@ -55,17 +51,18 @@ public InMemoryRateLimiterFilterFactory(InMemoryRateLimiter rateLimiter, KeyReso @Override public GatewayFilter apply(Config config) { + this.rateLimiter.setParameters(config.capacity, config.tokens, config.refillDuration); return (exchange, chain) -> { List pathElements = exchange.getRequest().getPath().elements(); String requestPath = (!pathElements.isEmpty() && pathElements.size() > 1) ? pathElements.get(1).value() : null; - if (requestPath == null || !serviceIds.contains(requestPath)) { + if (requestPath == null) { return chain.filter(exchange); } return keyResolver.resolve(exchange).flatMap(key -> { if (key.isEmpty()) { return chain.filter(exchange); } - return rateLimiter.isAllowed(config.getRouteId(), key).flatMap(response -> { + return rateLimiter.isAllowed(requestPath, key).flatMap(response -> { if (response.isAllowed()) { return chain.filter(exchange); } else { @@ -87,9 +84,8 @@ public GatewayFilter apply(Config config) { @Getter @Setter public static class Config { - private String routeId; - private Integer capacity; - private Integer tokens; - private Integer refillIntervalSeconds; + private int capacity; + private int tokens; + private int refillDuration; } } diff --git a/gateway-service/src/main/java/org/zowe/apiml/gateway/service/RouteLocator.java b/gateway-service/src/main/java/org/zowe/apiml/gateway/service/RouteLocator.java index bbb51d53cb..102083dc49 100644 --- a/gateway-service/src/main/java/org/zowe/apiml/gateway/service/RouteLocator.java +++ b/gateway-service/src/main/java/org/zowe/apiml/gateway/service/RouteLocator.java @@ -49,6 +49,9 @@ public class RouteLocator implements RouteDefinitionLocator { @Value("${apiml.service.forwardClientCertEnabled:false}") private boolean forwardingClientCertEnabled; + @Value("${apiml.gateway.servicesToLimitRequestRate:-}") + List servicesToLimitRequestRate; + private final ApplicationContext context; private final CorsUtils corsUtils; @@ -140,6 +143,21 @@ List getPostRoutingFilters(ServiceInstance serviceInstance) { serviceRelated.add(forbidEncodedCharactersFilter); } + if (Optional.ofNullable(serviceInstance.getMetadata().get(APPLY_RATE_LIMITER_FILTER)) + .map(Boolean::parseBoolean) + .orElse(false)) { + FilterDefinition rateLimiterFilter = new FilterDefinition(); + rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory"); + rateLimiterFilter.addArg("capacity", serviceInstance.getMetadata().get("apiml.gateway.rateLimiterCapacity")); + rateLimiterFilter.addArg("tokens", serviceInstance.getMetadata().get("apiml.gateway.rateLimiterTokens")); + rateLimiterFilter.addArg("refillDuration", serviceInstance.getMetadata().get("apiml.gateway.rateLimiterRefillDuration")); + serviceRelated.add(rateLimiterFilter); + } else if (servicesToLimitRequestRate != null && servicesToLimitRequestRate.contains(serviceInstance.getServiceId().toLowerCase())) { + FilterDefinition rateLimiterFilter = new FilterDefinition(); + rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory"); + serviceRelated.add(rateLimiterFilter); + } + FilterDefinition pageRedirectionFilter = new FilterDefinition(); pageRedirectionFilter.setName("PageRedirectionFilterFactory"); pageRedirectionFilter.addArg("serviceId", serviceInstance.getServiceId()); diff --git a/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactoryTest.java b/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactoryTest.java index cd39912e79..fff902bb6e 100644 --- a/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactoryTest.java +++ b/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterFilterFactoryTest.java @@ -27,7 +27,6 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -40,7 +39,6 @@ public class InMemoryRateLimiterFilterFactoryTest { private InMemoryRateLimiterFilterFactory filterFactory; private ServerWebExchange exchange; private GatewayFilterChain chain; - private String serviceId; private MockServerHttpRequest request; private InMemoryRateLimiterFilterFactory.Config config; private MessageService messageService; @@ -49,19 +47,16 @@ public class InMemoryRateLimiterFilterFactoryTest { @BeforeEach public void setUp() { - serviceId = "testService"; rateLimiter = mock(InMemoryRateLimiter.class); keyResolver = mock(KeyResolver.class); messageService = mock(MessageService.class); message = mock(Message.class); objectMapper = mock(ObjectMapper.class); filterFactory = new InMemoryRateLimiterFilterFactory(rateLimiter, keyResolver,objectMapper, messageService); - filterFactory.serviceIds = List.of(serviceId); - request = MockServerHttpRequest.get("/" + serviceId).build(); + request = MockServerHttpRequest.get("/" + "serviceId").build(); exchange = MockServerWebExchange.from(request); chain = mock(GatewayFilterChain.class); config = mock(InMemoryRateLimiterFilterFactory.Config.class); - when(config.getRouteId()).thenReturn("testRoute"); } @Test @@ -115,20 +110,6 @@ public void apply_shouldAllowRequest_whenKeyIsNull() { verify(chain, times(1)).filter(exchange); } - @Test - public void apply_shouldAllowRequest_whenServiceIdDoesNotMatch() { - String nonMatchingId = "nonMatchingId"; - when(keyResolver.resolve(exchange)).thenReturn(Mono.just("testKey")); - request = MockServerHttpRequest.get("/" + nonMatchingId).build(); - exchange = MockServerWebExchange.from(request); - when(chain.filter(any(ServerWebExchange.class))).thenReturn(Mono.empty()); - - StepVerifier.create(filterFactory.apply(config).filter(exchange, chain)) - .expectComplete() - .verify(); - verify(chain, times(1)).filter(exchange); - } - @Test public void apply_shouldAllowRequest_whenServiceIdEmpty() { when(keyResolver.resolve(exchange)).thenReturn(Mono.just("testKey")); diff --git a/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterTest.java b/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterTest.java index 65ebfa3848..18b1d2e2d3 100644 --- a/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterTest.java +++ b/gateway-service/src/test/java/org/zowe/apiml/gateway/filters/InMemoryRateLimiterTest.java @@ -88,4 +88,23 @@ public void testNewConfig() { assertEquals(rateLimiter.tokens, config.getTokens(), "Config tokens should match the rate limiter tokens"); assertEquals(rateLimiter.refillDuration, config.getRefillDuration(), "Config refill duration should match the rate limiter refill duration"); } + + @Test + public void setNonNullParametersTest() { + Integer newCapacity = 20; + Integer newTokens = 20; + Integer newRefillDuration = 2; + rateLimiter.setParameters(newCapacity, newTokens, newRefillDuration); + assertEquals(newCapacity, rateLimiter.capacity); + assertEquals(newTokens, rateLimiter.tokens); + assertEquals(newRefillDuration, rateLimiter.refillDuration); + } + + @Test + public void setParametersWithNullValuesTest() { + Integer newCapacity = 30; + rateLimiter.setParameters(newCapacity, 0, 0); + assertEquals(newCapacity, rateLimiter.capacity); + + } } diff --git a/gateway-service/src/test/java/org/zowe/apiml/gateway/service/RouteLocatorTest.java b/gateway-service/src/test/java/org/zowe/apiml/gateway/service/RouteLocatorTest.java index 9a4736cc77..e4b45ef9aa 100644 --- a/gateway-service/src/test/java/org/zowe/apiml/gateway/service/RouteLocatorTest.java +++ b/gateway-service/src/test/java/org/zowe/apiml/gateway/service/RouteLocatorTest.java @@ -265,7 +265,7 @@ class PostRoutingFilterDefinition { private final List COMMON_FILTERS = Collections.singletonList(mock(FilterDefinition.class)); private final RouteLocator routeLocator = new RouteLocator(null, null, null, COMMON_FILTERS, Collections.emptyList(), null); - private ServiceInstance createServiceInstance(Boolean forwardingEnabled, Boolean encodedCharactersEnabled) { + private ServiceInstance createServiceInstance(Boolean forwardingEnabled, Boolean encodedCharactersEnabled, Boolean rateLimiterEnabled) { Map metadata = new HashMap<>(); if (forwardingEnabled != null) { metadata.put(SERVICE_SUPPORTING_CLIENT_CERT_FORWARDING, String.valueOf(forwardingEnabled)); @@ -273,6 +273,9 @@ private ServiceInstance createServiceInstance(Boolean forwardingEnabled, Boolean if (encodedCharactersEnabled != null) { metadata.put(ENABLE_URL_ENCODED_CHARACTERS, String.valueOf(encodedCharactersEnabled)); } + if (rateLimiterEnabled != null) { + metadata.put(APPLY_RATE_LIMITER_FILTER, String.valueOf(rateLimiterEnabled)); + } ServiceInstance serviceInstance = mock(ServiceInstance.class); doReturn(metadata).when(serviceInstance).getMetadata(); return serviceInstance; @@ -290,7 +293,7 @@ void enableForwarding() { @Test void givenServiceAllowingCertForwarding_whenGetPostRoutingFilters_thenAddClientCertFilterFactory() { - ServiceInstance serviceInstance = createServiceInstance(Boolean.TRUE, null); + ServiceInstance serviceInstance = createServiceInstance(Boolean.TRUE, null, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertEquals(3, filterDefinitions.size()); // common filters + PageRedirectionFilterFactory @@ -299,7 +302,7 @@ void givenServiceAllowingCertForwarding_whenGetPostRoutingFilters_thenAddClientC @Test void givenServiceNotAllowingCertForwarding_whenGetPostRoutingFilters_thenReturnJustCommon() { - ServiceInstance serviceInstance = createServiceInstance(Boolean.FALSE, null); + ServiceInstance serviceInstance = createServiceInstance(Boolean.FALSE, null, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); @@ -310,7 +313,7 @@ void givenServiceNotAllowingCertForwarding_whenGetPostRoutingFilters_thenReturnJ @Test void givenServiceWithoutCertForwardingConfig_whenGetPostRoutingFilters_thenReturnJustCommon() { - ServiceInstance serviceInstance = createServiceInstance(null, null); + ServiceInstance serviceInstance = createServiceInstance(null, null, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); @@ -330,7 +333,7 @@ void disableForwarding() { @Test void givenAnyService_whenGetPostRoutingFilters_thenReturnJustCommon() { - ServiceInstance serviceInstance = createServiceInstance(Boolean.TRUE, null); + ServiceInstance serviceInstance = createServiceInstance(Boolean.TRUE, null, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); @@ -345,7 +348,7 @@ class EncodedCharacters { @Test void givenServiceAllowingEncodedCharacters_whenGetPostRoutingFilters_thenReturnJustCommon() { - ServiceInstance serviceInstance = createServiceInstance(null, Boolean.TRUE); + ServiceInstance serviceInstance = createServiceInstance(null, Boolean.TRUE, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); assertEquals(2, filterDefinitions.size()); @@ -354,7 +357,7 @@ void givenServiceAllowingEncodedCharacters_whenGetPostRoutingFilters_thenReturnJ @Test void givenServiceNotAllowingEncodedCharacters_whenGetPostRoutingFilters_thenAddEncodedCharacterFilterFactory() { - ServiceInstance serviceInstance = createServiceInstance(null, Boolean.FALSE); + ServiceInstance serviceInstance = createServiceInstance(null, Boolean.FALSE, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertEquals(3, filterDefinitions.size()); assertEquals("ForbidEncodedCharactersFilterFactory", filterDefinitions.get(1).getName()); @@ -362,7 +365,7 @@ void givenServiceNotAllowingEncodedCharacters_whenGetPostRoutingFilters_thenAddE @Test void givenServiceWithoutAllowingEncodedCharacters_whenGetPostRoutingFilters_thenAddEncodedCharacterFilterFactory() { - ServiceInstance serviceInstance = createServiceInstance(null, null); + ServiceInstance serviceInstance = createServiceInstance(null, null, null); List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); assertEquals(2, filterDefinitions.size()); @@ -371,6 +374,37 @@ void givenServiceWithoutAllowingEncodedCharacters_whenGetPostRoutingFilters_then } + @Nested + class RateLimiter { + + @Test + void givenServiceNotAllowingRateLimiter_whenGetPostRoutingFilters_thenReturnJustCommon() { + ServiceInstance serviceInstance = createServiceInstance(null, null, Boolean.FALSE); + List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); + assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); + assertEquals(2, filterDefinitions.size()); + assertTrue(filterDefinitions.stream().noneMatch(filter -> "InMemoryRateLimiterFilterFactory".equals(filter.getName()))); + } + + @Test + void givenServiceAllowingRateLimiter_whenGetPostRoutingFilters_thenAddInMemoryRateLimiterFilterFactory() { + ServiceInstance serviceInstance = createServiceInstance(null, null, Boolean.TRUE); + List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); + assertEquals(3, filterDefinitions.size()); + assertEquals("InMemoryRateLimiterFilterFactory", filterDefinitions.get(1).getName()); + } + + @Test + void givenServiceWithoutAllowingRateLimiter_whenGetPostRoutingFilters_thenDoNotAddInMemoryRateLimiterFilterFactory() { + ServiceInstance serviceInstance = createServiceInstance(null, null, null); + List filterDefinitions = routeLocator.getPostRoutingFilters(serviceInstance); + assertTrue(filterDefinitions.containsAll(COMMON_FILTERS), "Not all common filters are defined"); + assertEquals(2, filterDefinitions.size()); + assertTrue(filterDefinitions.stream().noneMatch(filter -> "InMemoryRateLimiterFilterFactory".equals(filter.getName()))); + } + + } + } }