Skip to content

Commit

Permalink
feat: add per ip limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
helintongh committed Mar 27, 2024
1 parent 47490f7 commit 6b592f3
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 30 deletions.
4 changes: 4 additions & 0 deletions include/cinatra/coro_http_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,10 @@ class coro_http_connection
return ss.str();
}

std::string remote_ip_address() {
return socket_.remote_endpoint().address().to_string();
}

void set_multi_buf(bool r) { multi_buf_ = r; }

async_simple::coro::Lazy<bool> write_data(std::string_view message) {
Expand Down
71 changes: 59 additions & 12 deletions include/cinatra/coro_http_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ enum class file_resp_format_type {
chunked,
range,
};
enum class limiter_type {
disable,
global,
perip,
};
class coro_http_server {
public:
coro_http_server(asio::io_context &ctx, unsigned short port)
Expand Down Expand Up @@ -489,13 +494,24 @@ class coro_http_server {
return connections_.size();
}

void set_rate_limiter(bool is_enable, int gen_rate = 0, int burst_size = 0) {
need_rate_limiter_ = is_enable;
if (need_rate_limiter_) {
void set_rate_limiter(enum limiter_type type, int gen_rate = 0,
int burst_size = 0) {
limiter_type_ = type;
if (limiter_type_ == limiter_type::global) {
token_bucket_.reset(gen_rate, burst_size);
}
else if (limiter_type_ == limiter_type::perip) {
per_rate_ = gen_rate;
per_burst_size_ = burst_size;
}
else {
per_rate_ = 0.0;
per_burst_size_ = 0.0;
}
}

void clear_per_ip_limiter_cache() { req_limiter_.clear(); }

private:
std::errc listen() {
CINATRA_LOG_INFO << "begin to listen";
Expand Down Expand Up @@ -587,17 +603,44 @@ class coro_http_server {
connections_.emplace(conn_id, conn);
}

if (need_rate_limiter_) {
if (token_bucket_.consume(1)) {
// there are enough tokens to allow request.
switch (limiter_type_) {
case limiter_type::disable: {
start_one(conn).via(&conn->get_executor()).detach();
break;
}
else {
conn->close();
case limiter_type::global: {
if (token_bucket_.consume(1)) {
// there are enough tokens to allow request.
start_one(conn).via(&conn->get_executor()).detach();
}
else {
conn->close();
}
break;
}
case limiter_type::perip: {
if (req_limiter_.empty() ||
req_limiter_.find(conn->remote_ip_address()) ==
req_limiter_.end()) {
req_limiter_.insert(
std::make_pair<std::string, std::shared_ptr<token_bucket>>(
conn->remote_ip_address(),
std::make_shared<token_bucket>(per_rate_,
per_burst_size_)));
}

if (req_limiter_[conn->remote_ip_address()]->consume(1)) {
start_one(conn).via(&conn->get_executor()).detach();
}
else {
conn->close();
}
break;
}
default: {
start_one(conn).via(&conn->get_executor()).detach();
break;
}
}
else {
start_one(conn).via(&conn->get_executor()).detach();
}
}
}
Expand Down Expand Up @@ -805,10 +848,14 @@ class coro_http_server {
coro_http_router router_;
bool need_shrink_every_time_ = false;

bool need_rate_limiter_ = false;
enum limiter_type limiter_type_ = limiter_type::disable;
// 100 tokens are generated per second
// and the maximum number of token buckets is 100
token_bucket token_bucket_ = token_bucket{100, 100};

double per_rate_ = 0.0;
double per_burst_size_ = 0.0;
std::unordered_map<std::string, std::shared_ptr<token_bucket>> req_limiter_;
};

using http_server = coro_http_server;
Expand Down
106 changes: 90 additions & 16 deletions lang/token_bucket.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,45 +81,119 @@ bool consume_with_borrow_and_wait(double to_consume, double now_in_seconds = def
在coro_http_server.hpp中的`coro_http_server`类中添加令牌桶成员变量:

```cpp
bool need_rate_limiter_ = false;
enum limiter_type limiter_type_ = limiter_type::disable;
// 100 tokens are generated per second
// and the maximum number of token buckets is 100
token_bucket token_bucket_ = token_bucket{100, 100};
```
一个是是否需要速率限制的标记,一个是令牌桶
一个是是否需要速率限制的标记,一个是全局令牌桶
增加一个开启令牌桶功能的API:
目前有三个类型:
```cpp
void set_rate_limiter(bool is_enable, int gen_rate = 0, int burst_size = 0) {
need_rate_limiter_ = is_enable;
if (need_rate_limiter_) {
enum class limiter_type {
disable,
global,
perip,
};
```

默认为disable不开启,global为开启全局令牌桶,perip为开启每IP请求限制。


下面是开关令牌桶功能的API:

```cpp
void set_rate_limiter(enum limiter_type type, int gen_rate = 0,
int burst_size = 0) {
limiter_type_ = type;
if (limiter_type_ == limiter_type::global) {
token_bucket_.reset(gen_rate, burst_size);
}
else if (limiter_type_ == limiter_type::perip) {
per_rate_ = gen_rate;
per_burst_size_ = burst_size;
}
else {
per_rate_ = 0.0;
per_burst_size_ = 0.0;
}
}
```
最后将令牌桶逻辑嵌入到`accept()`函数中,将之前的`start_one(conn).via(&conn->get_executor()).detach()`改为如下:
```cpp

if (need_rate_limiter_) {
if (token_bucket_.consume(1)) {
// there are enough tokens to allow request.
switch (limiter_type_) {
case limiter_type::disable: {
start_one(conn).via(&conn->get_executor()).detach();
break;
}
else {
conn->close();
case limiter_type::global: {
if (token_bucket_.consume(1)) {
// there are enough tokens to allow request.
start_one(conn).via(&conn->get_executor()).detach();
}
else {
conn->close();
}
break;
}
case limiter_type::perip: {
// per ip功能后文解释,这里省略
break;
}
default: {
start_one(conn).via(&conn->get_executor()).detach();
break;
}
}
else {
start_one(conn).via(&conn->get_executor()).detach();
}
```

当没有设置流控时走原来的处理逻辑。当走流控时,每次请求来的时候消耗一个令牌,若桶中没有令牌了则直接关闭连接,此时客户端会收到404错误。

这就是全局令牌桶的逻辑。

通过提供的令牌桶算法,用户可以很轻松的通过cinatra的切片功能完成每IP流控等功能。
# cinatra每ip请求限制功能逻辑

新增三个变量,per_rate_记录每个ip的令牌生成速率(秒为单位),per_burst_size_为每个ip的令牌桶最大令牌数,一个hash表其中key为ip地址,value为该ip地址所拥有的令牌桶。

```cpp
double per_rate_ = 0.0;
double per_burst_size_ = 0.0;
std::unordered_map<std::string, std::shared_ptr<token_bucket>> req_limiter_;
```

逻辑是当客户端连接到cinatra服务端的时候,如果第一次到则生成该ip令牌桶,将该ip作为key,生成的令牌桶作为value插入到hash表中。

请求时需要根据客户端ip从hash表中取出其令牌桶,当该ip对应的令牌桶中令牌数大于1的时候运行请求,否则禁止其访问。此时客户端会收到404错误。

具体逻辑如下:

```cpp
if (req_limiter_.empty() ||
req_limiter_.find(conn->remote_ip_address()) ==
req_limiter_.end()) {
req_limiter_.insert(
std::make_pair<std::string, std::shared_ptr<token_bucket>>(
conn->remote_ip_address(),
std::make_shared<token_bucket>(per_rate_,
per_burst_size_)));
}

if (req_limiter_[conn->remote_ip_address()]->consume(1)) {
start_one(conn).via(&conn->get_executor()).detach();
}
else {
conn->close();
}
```
最后每IP请求限制还提供了一个清空缓存的API:
```cpp
void clear_per_ip_limiter_cache() { req_limiter_.clear(); }
```

缓存是容易被攻击,不停地向主机发送垃圾请求,此时就会因此不停地刷新cache造成每个请求在都会先在IP请求表中进行搜索。所以需要提供清理cache的函数。
36 changes: 34 additions & 2 deletions tests/test_coro_http_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1592,10 +1592,10 @@ TEST_CASE("test reverse proxy") {
CHECK(!resp_random.resp_body.empty());
}

TEST_CASE("test token bucket") {
TEST_CASE("test global rate limiter") {
cinatra::coro_http_server server(1, 9001);

server.set_rate_limiter(true, 3, 3);
server.set_rate_limiter(cinatra::limiter_type::global, 3, 3);
server.set_http_handler<cinatra::GET, cinatra::POST>(
"/",
[](coro_http_request &req,
Expand All @@ -1620,4 +1620,36 @@ TEST_CASE("test token bucket") {
std::this_thread::sleep_for(1s);
result = client5.get("http://127.0.0.1:9001/");
CHECK(result.status == 200);
}

TEST_CASE("test per ip rate limiter") {
cinatra::coro_http_server server(1, 9001);

server.set_rate_limiter(cinatra::limiter_type::perip, 3, 3);
server.set_http_handler<cinatra::GET, cinatra::POST>(
"/",
[](coro_http_request &req,
coro_http_response &response) -> async_simple::coro::Lazy<void> {
co_await coro_io::post([&]() {
response.set_status_and_content(status_type::ok, "ok");
});
});

server.async_start();
std::this_thread::sleep_for(1s);

coro_http_client client1, client2, client3, client4, client5, client6;
auto result = client1.get("http://127.0.0.1:9001/");
CHECK(result.status == 200);
result = client2.get("http://127.0.0.1:9001/");
CHECK(result.status == 200);
result = client3.get("http://127.0.0.1:9001/");
CHECK(result.status == 200);
result = client4.get("http://127.0.0.1:9001/");
CHECK(result.status == 404);
result = client5.get("http://127.0.0.1:9001/");
CHECK(result.status == 404);
std::this_thread::sleep_for(1s);
result = client6.get("http://127.0.0.1:9001/");
CHECK(result.status == 200);
}

0 comments on commit 6b592f3

Please sign in to comment.