diff --git a/include/cinatra/coro_http_connection.hpp b/include/cinatra/coro_http_connection.hpp index f0590379..2bef3b48 100644 --- a/include/cinatra/coro_http_connection.hpp +++ b/include/cinatra/coro_http_connection.hpp @@ -405,6 +405,10 @@ class coro_http_connection return ss.str(); } + std::string remote_ip_address() { + return socket_.remote_endpoint().address().to_string(); + } + void set_multi_buf(bool r) { multi_buf_ = r; } async_simple::coro::Lazy write_data(std::string_view message) { diff --git a/include/cinatra/coro_http_server.hpp b/include/cinatra/coro_http_server.hpp index 4e294b93..a9d7c3d5 100644 --- a/include/cinatra/coro_http_server.hpp +++ b/include/cinatra/coro_http_server.hpp @@ -26,6 +26,11 @@ enum class file_resp_format_type { chunked, range, }; +enum class limiter_type { + disable, + global, + perip, +}; class coro_http_server { public: coro_http_server(asio::io_context &ctx, unsigned short port) @@ -489,13 +494,24 @@ class coro_http_server { return connections_.size(); } - void set_rate_limiter(bool is_enable, int gen_rate = 0, int burst_size = 0) { - need_rate_limiter_ = is_enable; - if (need_rate_limiter_) { + void set_rate_limiter(enum limiter_type type, int gen_rate = 0, + int burst_size = 0) { + limiter_type_ = type; + if (limiter_type_ == limiter_type::global) { token_bucket_.reset(gen_rate, burst_size); } + else if (limiter_type_ == limiter_type::perip) { + per_rate_ = gen_rate; + per_burst_size_ = burst_size; + } + else { + per_rate_ = 0.0; + per_burst_size_ = 0.0; + } } + void clear_per_ip_limiter_cache() { req_limiter_.clear(); } + private: std::errc listen() { CINATRA_LOG_INFO << "begin to listen"; @@ -587,17 +603,44 @@ class coro_http_server { connections_.emplace(conn_id, conn); } - if (need_rate_limiter_) { - if (token_bucket_.consume(1)) { - // there are enough tokens to allow request. + switch (limiter_type_) { + case limiter_type::disable: { start_one(conn).via(&conn->get_executor()).detach(); + break; } - else { - conn->close(); + case limiter_type::global: { + if (token_bucket_.consume(1)) { + // there are enough tokens to allow request. + start_one(conn).via(&conn->get_executor()).detach(); + } + else { + conn->close(); + } + break; + } + case limiter_type::perip: { + if (req_limiter_.empty() || + req_limiter_.find(conn->remote_ip_address()) == + req_limiter_.end()) { + req_limiter_.insert( + std::make_pair>( + conn->remote_ip_address(), + std::make_shared(per_rate_, + per_burst_size_))); + } + + if (req_limiter_[conn->remote_ip_address()]->consume(1)) { + start_one(conn).via(&conn->get_executor()).detach(); + } + else { + conn->close(); + } + break; + } + default: { + start_one(conn).via(&conn->get_executor()).detach(); + break; } - } - else { - start_one(conn).via(&conn->get_executor()).detach(); } } } @@ -805,10 +848,14 @@ class coro_http_server { coro_http_router router_; bool need_shrink_every_time_ = false; - bool need_rate_limiter_ = false; + enum limiter_type limiter_type_ = limiter_type::disable; // 100 tokens are generated per second // and the maximum number of token buckets is 100 token_bucket token_bucket_ = token_bucket{100, 100}; + + double per_rate_ = 0.0; + double per_burst_size_ = 0.0; + std::unordered_map> req_limiter_; }; using http_server = coro_http_server; diff --git a/lang/token_bucket.md b/lang/token_bucket.md index 7ad8e7fa..a51f59f9 100644 --- a/lang/token_bucket.md +++ b/lang/token_bucket.md @@ -81,40 +81,73 @@ bool consume_with_borrow_and_wait(double to_consume, double now_in_seconds = def 在coro_http_server.hpp中的`coro_http_server`类中添加令牌桶成员变量: ```cpp - bool need_rate_limiter_ = false; + enum limiter_type limiter_type_ = limiter_type::disable; // 100 tokens are generated per second // and the maximum number of token buckets is 100 token_bucket token_bucket_ = token_bucket{100, 100}; ``` -一个是是否需要速率限制的标记,一个是令牌桶。 +一个是是否需要速率限制的标记,一个是全局令牌桶。 -增加一个开启令牌桶功能的API: +目前有三个类型: ```cpp - void set_rate_limiter(bool is_enable, int gen_rate = 0, int burst_size = 0) { - need_rate_limiter_ = is_enable; - if (need_rate_limiter_) { +enum class limiter_type { + disable, + global, + perip, +}; +``` + +默认为disable不开启,global为开启全局令牌桶,perip为开启每IP请求限制。 + + +下面是开关令牌桶功能的API: + +```cpp + void set_rate_limiter(enum limiter_type type, int gen_rate = 0, + int burst_size = 0) { + limiter_type_ = type; + if (limiter_type_ == limiter_type::global) { token_bucket_.reset(gen_rate, burst_size); } + else if (limiter_type_ == limiter_type::perip) { + per_rate_ = gen_rate; + per_burst_size_ = burst_size; + } + else { + per_rate_ = 0.0; + per_burst_size_ = 0.0; + } } ``` 最后将令牌桶逻辑嵌入到`accept()`函数中,将之前的`start_one(conn).via(&conn->get_executor()).detach()`改为如下: ```cpp - -if (need_rate_limiter_) { - if (token_bucket_.consume(1)) { - // there are enough tokens to allow request. +switch (limiter_type_) { + case limiter_type::disable: { start_one(conn).via(&conn->get_executor()).detach(); + break; } - else { - conn->close(); + case limiter_type::global: { + if (token_bucket_.consume(1)) { + // there are enough tokens to allow request. + start_one(conn).via(&conn->get_executor()).detach(); + } + else { + conn->close(); + } + break; + } + case limiter_type::perip: { + // per ip功能后文解释,这里省略 + break; + } + default: { + start_one(conn).via(&conn->get_executor()).detach(); + break; } -} -else { - start_one(conn).via(&conn->get_executor()).detach(); } ``` @@ -122,4 +155,45 @@ else { 这就是全局令牌桶的逻辑。 -通过提供的令牌桶算法,用户可以很轻松的通过cinatra的切片功能完成每IP流控等功能。 \ No newline at end of file +# cinatra每ip请求限制功能逻辑 + +新增三个变量,per_rate_记录每个ip的令牌生成速率(秒为单位),per_burst_size_为每个ip的令牌桶最大令牌数,一个hash表其中key为ip地址,value为该ip地址所拥有的令牌桶。 + +```cpp + double per_rate_ = 0.0; + double per_burst_size_ = 0.0; + std::unordered_map> req_limiter_; +``` + +逻辑是当客户端连接到cinatra服务端的时候,如果第一次到则生成该ip令牌桶,将该ip作为key,生成的令牌桶作为value插入到hash表中。 + +请求时需要根据客户端ip从hash表中取出其令牌桶,当该ip对应的令牌桶中令牌数大于1的时候运行请求,否则禁止其访问。此时客户端会收到404错误。 + +具体逻辑如下: + +```cpp +if (req_limiter_.empty() || + req_limiter_.find(conn->remote_ip_address()) == + req_limiter_.end()) { + req_limiter_.insert( + std::make_pair>( + conn->remote_ip_address(), + std::make_shared(per_rate_, + per_burst_size_))); +} + +if (req_limiter_[conn->remote_ip_address()]->consume(1)) { + start_one(conn).via(&conn->get_executor()).detach(); +} +else { + conn->close(); +} +``` + +最后每IP请求限制还提供了一个清空缓存的API: + +```cpp +void clear_per_ip_limiter_cache() { req_limiter_.clear(); } +``` + +缓存是容易被攻击,不停地向主机发送垃圾请求,此时就会因此不停地刷新cache造成每个请求在都会先在IP请求表中进行搜索。所以需要提供清理cache的函数。 \ No newline at end of file diff --git a/tests/test_coro_http_server.cpp b/tests/test_coro_http_server.cpp index 51ad3106..e1196d7a 100644 --- a/tests/test_coro_http_server.cpp +++ b/tests/test_coro_http_server.cpp @@ -1592,10 +1592,10 @@ TEST_CASE("test reverse proxy") { CHECK(!resp_random.resp_body.empty()); } -TEST_CASE("test token bucket") { +TEST_CASE("test global rate limiter") { cinatra::coro_http_server server(1, 9001); - server.set_rate_limiter(true, 3, 3); + server.set_rate_limiter(cinatra::limiter_type::global, 3, 3); server.set_http_handler( "/", [](coro_http_request &req, @@ -1620,4 +1620,36 @@ TEST_CASE("test token bucket") { std::this_thread::sleep_for(1s); result = client5.get("http://127.0.0.1:9001/"); CHECK(result.status == 200); +} + +TEST_CASE("test per ip rate limiter") { + cinatra::coro_http_server server(1, 9001); + + server.set_rate_limiter(cinatra::limiter_type::perip, 3, 3); + server.set_http_handler( + "/", + [](coro_http_request &req, + coro_http_response &response) -> async_simple::coro::Lazy { + co_await coro_io::post([&]() { + response.set_status_and_content(status_type::ok, "ok"); + }); + }); + + server.async_start(); + std::this_thread::sleep_for(1s); + + coro_http_client client1, client2, client3, client4, client5, client6; + auto result = client1.get("http://127.0.0.1:9001/"); + CHECK(result.status == 200); + result = client2.get("http://127.0.0.1:9001/"); + CHECK(result.status == 200); + result = client3.get("http://127.0.0.1:9001/"); + CHECK(result.status == 200); + result = client4.get("http://127.0.0.1:9001/"); + CHECK(result.status == 404); + result = client5.get("http://127.0.0.1:9001/"); + CHECK(result.status == 404); + std::this_thread::sleep_for(1s); + result = client6.get("http://127.0.0.1:9001/"); + CHECK(result.status == 200); } \ No newline at end of file