Skip to content

Commit

Permalink
feat: rate limit per service (#3903)
Browse files Browse the repository at this point in the history
* configure rate limiter per service
  • Loading branch information
kishkinova authored Nov 19, 2024
1 parent d76cdf7 commit cad63cb
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
ZWE_CONFIGS_APIML_SERVICE_ADDITIONALREGISTRATION_0_DISCOVERYSERVICEURLS: https://discovery-service-2:10011/eureka
SERVER_MAX_HTTP_REQUEST_HEADER_SIZE: 16348
SERVER_WEBSOCKET_REQUESTBUFFERSIZE: 16348
APIML_GATEWAY_ROUTING_SERVICESTOLIMITREQUESTRATE: discoverableclient
APIML_GATEWAY_SERVICESTOLIMITREQUESTRATE: discoverableclient
APIML_GATEWAY_ROUTING_COOKIENAMEFORRATELIMIT: apimlAuthenticationToken
zaas-service:
image: ghcr.io/balhar-jakub/zaas-service:${{ github.run_id }}-${{ github.run_number }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private EurekaMetadataDefinition() {
public static final String SERVICE_EXTERNAL_URL = "apiml.service.externalUrl";
public static final String SERVICE_SUPPORTING_CLIENT_CERT_FORWARDING = "apiml.service.supportClientCertForwarding";
public static final String ENABLE_URL_ENCODED_CHARACTERS = "apiml.enableUrlEncodedCharacters";
public static final String APPLY_RATE_LIMITER_FILTER = "apiml.gateway.applyRateLimiterFilter";
public static final String APIML_ID = "apiml.service.apimlId";
public static final String REGISTRATION_TYPE = "apiml.registrationType";

Expand Down
8 changes: 4 additions & 4 deletions gateway-package/src/main/resources/bin/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,10 @@ _BPX_JOBNAME=${ZWE_zowe_job_prefix}${GATEWAY_CODE} ${JAVA_BIN_DIR}java \
-Dapiml.gateway.cachePeriodSec=${ZWE_configs_apiml_gateway_registry_cachePeriodSec:-120} \
-Dapiml.gateway.registry.enabled=${ZWE_configs_apiml_gateway_registry_enabled:-false} \
-Dapiml.gateway.maxSimultaneousRequests=${ZWE_configs_gateway_registry_maxSimultaneousRequests:-20} \
-Dapiml.gateway.rateLimiterCapacity=${ZWE_configs_apiml_gateway_routing_rateLimiterCapacity:-20} \
-Dapiml.gateway.rateLimiterTokens=${ZWE_configs_apiml_gateway_routing_rateLimiterTokens:-20} \
-Dapiml.gateway.rateLimiterRefillDuration=${ZWE_configs_apiml_gateway_routing_rateLimiterRefillDuration:-1} \
-Dapiml.gateway.servicesToLimitRequestRate=${ZWE_configs_apiml_gateway_routing_servicesToLimitRequestRate:-} \
-Dapiml.gateway.rateLimiterCapacity=${ZWE_configs_apiml_gateway_rateLimiterCapacity:-20} \
-Dapiml.gateway.rateLimiterTokens=${ZWE_configs_apiml_gateway_rateLimiterTokens:-20} \
-Dapiml.gateway.rateLimiterRefillDuration=${ZWE_configs_apiml_gateway_rateLimiterRefillDuration:-1} \
-Dapiml.gateway.servicesToLimitRequestRate=${ZWE_configs_apiml_gateway_servicesToLimitRequestRate:-} \
-Dapiml.gateway.cookieNameForRateLimit=${cookieName:-apimlAuthenticationToken} \
-Dapiml.gateway.registry.metadata-key-allow-list=${ZWE_configs_gateway_registry_metadataKeyAllowList:-} \
-Dapiml.gateway.refresh-interval-ms=${ZWE_configs_gateway_registry_refreshIntervalMs:-30000} \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ public List<FilterDefinition> filters() {
retryFilter.addArg("series", "");
filters.add(retryFilter);

FilterDefinition rateLimiterFilter = new FilterDefinition();
rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory");
filters.add(rateLimiterFilter);

for (String headerName : ignoredHeadersWhenCorsEnabled.split(",")) {
FilterDefinition removeHeaders = new FilterDefinition();
removeHeaders.setName("RemoveRequestHeader");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@
public class InMemoryRateLimiter implements RateLimiter<InMemoryRateLimiter.Config> {

private final Map<String, Bucket> cache = new ConcurrentHashMap<>();
@Value("${apiml.gateway.routing.rateLimiterCapacity:20}")

@Value("${apiml.gateway.rateLimiterCapacity:20}")
int capacity;
@Value("${apiml.gateway.routing.rateLimiterTokens:20}")

@Value("${apiml.gateway.rateLimiterTokens:20}")
int tokens;
@Value("${apiml.gateway.routing.rateLimiterRefillDuration:1}")
Integer refillDuration;

@Value("${apiml.gateway.rateLimiterRefillDuration:1}")
int refillDuration;

@Override
public Mono<Response> isAllowed(String routeId, String id) {
Expand All @@ -55,6 +58,12 @@ private Map<String, String> getHeaders(Bucket bucket) {
return headers;
}

public void setParameters(int capacity, int tokens, int refillDuration) {
this.capacity = (capacity != 0) ? capacity : this.capacity;
this.tokens = (tokens != 0) ? tokens : this.tokens;
this.refillDuration = (refillDuration != 0) ? refillDuration : this.refillDuration;;
}

@Override
public Map<String, Config> getConfig() {
Config defaultConfig = new Config();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.Setter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpStatus;
Expand All @@ -34,13 +33,10 @@ public class InMemoryRateLimiterFilterFactory extends AbstractGatewayFilterFacto
@InjectApimlLogger
private ApimlLogger apimlLog = ApimlLogger.empty();

private final InMemoryRateLimiter rateLimiter;
private InMemoryRateLimiter rateLimiter;

private final KeyResolver keyResolver;

@Value("${apiml.gateway.routing.servicesToLimitRequestRate:-}")
List<String> serviceIds;

private final ObjectMapper mapper;

private final MessageService messageService;
Expand All @@ -55,17 +51,18 @@ public InMemoryRateLimiterFilterFactory(InMemoryRateLimiter rateLimiter, KeyReso

@Override
public GatewayFilter apply(Config config) {
this.rateLimiter.setParameters(config.capacity, config.tokens, config.refillDuration);
return (exchange, chain) -> {
List<PathContainer.Element> pathElements = exchange.getRequest().getPath().elements();
String requestPath = (!pathElements.isEmpty() && pathElements.size() > 1) ? pathElements.get(1).value() : null;
if (requestPath == null || !serviceIds.contains(requestPath)) {
if (requestPath == null) {
return chain.filter(exchange);
}
return keyResolver.resolve(exchange).flatMap(key -> {
if (key.isEmpty()) {
return chain.filter(exchange);
}
return rateLimiter.isAllowed(config.getRouteId(), key).flatMap(response -> {
return rateLimiter.isAllowed(requestPath, key).flatMap(response -> {
if (response.isAllowed()) {
return chain.filter(exchange);
} else {
Expand All @@ -87,9 +84,8 @@ public GatewayFilter apply(Config config) {
@Getter
@Setter
public static class Config {
private String routeId;
private Integer capacity;
private Integer tokens;
private Integer refillIntervalSeconds;
private int capacity;
private int tokens;
private int refillDuration;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public class RouteLocator implements RouteDefinitionLocator {
@Value("${apiml.service.forwardClientCertEnabled:false}")
private boolean forwardingClientCertEnabled;

@Value("${apiml.gateway.servicesToLimitRequestRate:-}")
List<String> servicesToLimitRequestRate;

private final ApplicationContext context;

private final CorsUtils corsUtils;
Expand Down Expand Up @@ -140,6 +143,21 @@ List<FilterDefinition> getPostRoutingFilters(ServiceInstance serviceInstance) {
serviceRelated.add(forbidEncodedCharactersFilter);
}

if (Optional.ofNullable(serviceInstance.getMetadata().get(APPLY_RATE_LIMITER_FILTER))
.map(Boolean::parseBoolean)
.orElse(false)) {
FilterDefinition rateLimiterFilter = new FilterDefinition();
rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory");
rateLimiterFilter.addArg("capacity", serviceInstance.getMetadata().get("apiml.gateway.rateLimiterCapacity"));
rateLimiterFilter.addArg("tokens", serviceInstance.getMetadata().get("apiml.gateway.rateLimiterTokens"));
rateLimiterFilter.addArg("refillDuration", serviceInstance.getMetadata().get("apiml.gateway.rateLimiterRefillDuration"));
serviceRelated.add(rateLimiterFilter);
} else if (servicesToLimitRequestRate != null && servicesToLimitRequestRate.contains(serviceInstance.getServiceId().toLowerCase())) {
FilterDefinition rateLimiterFilter = new FilterDefinition();
rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory");
serviceRelated.add(rateLimiterFilter);
}

FilterDefinition pageRedirectionFilter = new FilterDefinition();
pageRedirectionFilter.setName("PageRedirectionFilterFactory");
pageRedirectionFilter.addArg("serviceId", serviceInstance.getServiceId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand All @@ -40,7 +39,6 @@ public class InMemoryRateLimiterFilterFactoryTest {
private InMemoryRateLimiterFilterFactory filterFactory;
private ServerWebExchange exchange;
private GatewayFilterChain chain;
private String serviceId;
private MockServerHttpRequest request;
private InMemoryRateLimiterFilterFactory.Config config;
private MessageService messageService;
Expand All @@ -49,19 +47,16 @@ public class InMemoryRateLimiterFilterFactoryTest {

@BeforeEach
public void setUp() {
serviceId = "testService";
rateLimiter = mock(InMemoryRateLimiter.class);
keyResolver = mock(KeyResolver.class);
messageService = mock(MessageService.class);
message = mock(Message.class);
objectMapper = mock(ObjectMapper.class);
filterFactory = new InMemoryRateLimiterFilterFactory(rateLimiter, keyResolver,objectMapper, messageService);
filterFactory.serviceIds = List.of(serviceId);
request = MockServerHttpRequest.get("/" + serviceId).build();
request = MockServerHttpRequest.get("/" + "serviceId").build();
exchange = MockServerWebExchange.from(request);
chain = mock(GatewayFilterChain.class);
config = mock(InMemoryRateLimiterFilterFactory.Config.class);
when(config.getRouteId()).thenReturn("testRoute");
}

@Test
Expand Down Expand Up @@ -115,20 +110,6 @@ public void apply_shouldAllowRequest_whenKeyIsNull() {
verify(chain, times(1)).filter(exchange);
}

@Test
public void apply_shouldAllowRequest_whenServiceIdDoesNotMatch() {
String nonMatchingId = "nonMatchingId";
when(keyResolver.resolve(exchange)).thenReturn(Mono.just("testKey"));
request = MockServerHttpRequest.get("/" + nonMatchingId).build();
exchange = MockServerWebExchange.from(request);
when(chain.filter(any(ServerWebExchange.class))).thenReturn(Mono.empty());

StepVerifier.create(filterFactory.apply(config).filter(exchange, chain))
.expectComplete()
.verify();
verify(chain, times(1)).filter(exchange);
}

@Test
public void apply_shouldAllowRequest_whenServiceIdEmpty() {
when(keyResolver.resolve(exchange)).thenReturn(Mono.just("testKey"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,23 @@ public void testNewConfig() {
assertEquals(rateLimiter.tokens, config.getTokens(), "Config tokens should match the rate limiter tokens");
assertEquals(rateLimiter.refillDuration, config.getRefillDuration(), "Config refill duration should match the rate limiter refill duration");
}

@Test
public void setNonNullParametersTest() {
Integer newCapacity = 20;
Integer newTokens = 20;
Integer newRefillDuration = 2;
rateLimiter.setParameters(newCapacity, newTokens, newRefillDuration);
assertEquals(newCapacity, rateLimiter.capacity);
assertEquals(newTokens, rateLimiter.tokens);
assertEquals(newRefillDuration, rateLimiter.refillDuration);
}

@Test
public void setParametersWithNullValuesTest() {
Integer newCapacity = 30;
rateLimiter.setParameters(newCapacity, 0, 0);
assertEquals(newCapacity, rateLimiter.capacity);

}
}
Loading

0 comments on commit cad63cb

Please sign in to comment.