diff --git a/src/tateyama/transport/client_wire.h b/src/tateyama/transport/client_wire.h index 1498393..c4fb627 100644 --- a/src/tateyama/transport/client_wire.h +++ b/src/tateyama/transport/client_wire.h @@ -18,6 +18,7 @@ #include #include #include +#include #include // std::runtime_error #include "wire.h" @@ -261,25 +262,29 @@ class session_wire_container slot& my_slot = slot_status_.at(static_cast(slot_index)); while (true) { + { + std::unique_lock lock(mtx_receive_); + cnd_receive_.wait(lock, [this, slot_index]{ return slot_status_.at(static_cast(slot_index)).valid() || !using_wire_.load(); }); + } if (my_slot.valid()) { my_slot.consume(res_message); + cnd_receive_.notify_all(); return; } - { - std::unique_lock lock(mtx_receive_); - - // check again to avoid race - if (my_slot.valid()) { - my_slot.consume(res_message); - return; - } + bool expected = false; + if (!using_wire_.compare_exchange_weak(expected, true)) { + continue; + } + try { auto header_received = response_wire_.await(); auto index_received = header_received.get_idx(); if (index_received == slot_index) { res_message.resize(header_received.get_length()); response_wire_.read(res_message.data()); my_slot.receive_and_consume(header_received.get_type()); + using_wire_.store(false); + cnd_receive_.notify_all(); return; } auto& slot_received = slot_status_.at(static_cast(index_received)); @@ -287,6 +292,12 @@ class session_wire_container message_received.resize(header_received.get_length()); response_wire_.read(message_received.data()); slot_received.post_receive(); + using_wire_.store(false); + cnd_receive_.notify_all(); + } catch (std::runtime_error& ex) { + using_wire_.store(false); + cnd_receive_.notify_all(); + throw ex; } } } @@ -304,6 +315,8 @@ class session_wire_container std::array slot_status_{}; std::mutex mtx_send_{}; std::mutex mtx_receive_{}; + std::condition_variable cnd_receive_{}; + std::atomic_bool using_wire_{}; void dispose_resultset_wire(std::unique_ptr& container) { container->set_closed(); diff --git a/src/tateyama/transport/transport.h b/src/tateyama/transport/transport.h index 332e51c..a156a34 100644 --- a/src/tateyama/transport/transport.h +++ b/src/tateyama/transport/transport.h @@ -83,7 +83,7 @@ class transport { session_id_ = handshake_response.success().session_id(); header_.set_session_id(session_id_); - time_ = std::make_unique(EXPIRATION_SECONDS, [this](){ + timer_ = std::make_unique(EXPIRATION_SECONDS, [this](){ auto ret = update_expiration_time(); if (ret.has_value()) { return ret.value().result_case() == tateyama::proto::core::response::UpdateExpirationTime::ResultCase::kSuccess; @@ -94,7 +94,7 @@ class transport { ~transport() { try { - time_ = nullptr; + timer_ = nullptr; if (!closed_) { close(); } @@ -365,7 +365,7 @@ class transport { std::unique_ptr status_info_{}; std::size_t session_id_{}; bool closed_{}; - std::unique_ptr time_{}; + std::unique_ptr timer_{}; std::string digest() { auto bst_conf = configuration::bootstrap_configuration::create_bootstrap_configuration(FLAGS_conf); diff --git a/test/tateyama/test_utils/endpoint.h b/test/tateyama/test_utils/endpoint.h index a7d1995..5304fcd 100644 --- a/test/tateyama/test_utils/endpoint.h +++ b/test/tateyama/test_utils/endpoint.h @@ -55,13 +55,24 @@ class endpoint_response { class endpoint { constexpr static tateyama::common::wire::response_header::msg_type RESPONSE_BODY = 1; + class data_for_check { + public: + std::queue responses_{}; + tateyama::framework::component::id_type component_id_{}; + std::string current_request_{}; + std::size_t uet_count_{}; + }; + public: class worker { public: - worker(std::size_t session_id, std::unique_ptr wire, std::function clean_up, std::queue& responses, tateyama::framework::component::id_type& component_id, std::string& current_request) - : session_id_(session_id), wire_(std::move(wire)), clean_up_(std::move(clean_up)), responses_(responses), component_id_(component_id), current_request_(current_request), thread_(std::thread(std::ref(*this))) { + worker(std::size_t session_id, std::unique_ptr wire, std::function clean_up, data_for_check& data_for_check) + : session_id_(session_id), wire_(std::move(wire)), clean_up_(std::move(clean_up)), data_for_check_(data_for_check), thread_(std::thread(std::ref(*this))) { } ~worker() { + if (reply_thread_.joinable()) { + reply_thread_.join(); + } if (thread_.joinable()) { thread_.join(); } @@ -81,7 +92,7 @@ class endpoint { throw std::runtime_error("error parsing request message"); } std::stringstream ss{}; - if (component_id_ = req_header.service_id(); component_id_ == tateyama::framework::service_id_endpoint_broker) { + if (data_for_check_.component_id_ = req_header.service_id(); data_for_check_.component_id_ == tateyama::framework::service_id_endpoint_broker) { ::tateyama::proto::framework::response::Header header{}; if(auto res = tateyama::utils::SerializeDelimitedToOstream(header, std::addressof(ss)); ! res) { throw std::runtime_error("error formatting response message"); @@ -102,6 +113,7 @@ class endpoint { if(auto res = tateyama::utils::SerializeDelimitedToOstream(header, std::addressof(ss)); ! res) { throw std::runtime_error("error formatting response message"); } + data_for_check_.uet_count_++; tateyama::proto::core::response::UpdateExpirationTime rp{}; (void) rp.mutable_success(); auto body = rp.SerializeAsString(); @@ -127,18 +139,38 @@ class endpoint { auto reply_message = ss.str(); wire_->get_response_wire().write(reply_message.data(), tateyama::common::wire::response_header(index, reply_message.length(), RESPONSE_BODY)); continue; + } else if (req_header.service_id() == static_cast(2468)) { + std::string_view payload{}; + if (auto res = tateyama::utils::GetDelimitedBodyFromZeroCopyStream(std::addressof(in), nullptr, payload); ! res) { + throw std::runtime_error("error reading payload"); + } + auto t = std::stoi(std::string(payload)); + reply_thread_ = std::thread([this, t, index]{ + std::this_thread::sleep_for(std::chrono::seconds(t)); + std::stringstream ss{}; + ::tateyama::proto::framework::response::Header header{}; + if(auto res = tateyama::utils::SerializeDelimitedToOstream(header, std::addressof(ss)); ! res) { + throw std::runtime_error("error formatting response message"); + } + if(auto res = tateyama::utils::PutDelimitedBodyToOstream(std::string("OK"), std::addressof(ss)); ! res) { + throw std::runtime_error("error formatting response message"); + } + auto reply_message = ss.str(); + wire_->get_response_wire().write(reply_message.data(), tateyama::common::wire::response_header(index, reply_message.length(), RESPONSE_BODY)); + }); + continue; } { std::string_view payload{}; if (auto res = tateyama::utils::GetDelimitedBodyFromZeroCopyStream(std::addressof(in), nullptr, payload); ! res) { throw std::runtime_error("error reading payload"); } - current_request_ = payload; - if (responses_.empty()) { + data_for_check_.current_request_ = payload; + if (data_for_check_.responses_.empty()) { throw std::runtime_error("response queue is empty"); } - auto reply = responses_.front(); - responses_.pop(); + auto reply = data_for_check_.responses_.front(); + data_for_check_.responses_.pop(); std::stringstream ss{}; ::tateyama::proto::framework::response::Header header{}; header.set_payload_type(reply.get_type()); @@ -155,15 +187,21 @@ class endpoint { } clean_up_(); } + void finish() { + wire_->finish(); + } + void suppress_message() { + wire_->suppress_message(); + } private: std::size_t session_id_; std::unique_ptr wire_; std::function clean_up_; - std::queue& responses_; - tateyama::framework::component::id_type& component_id_; - std::string& current_request_; + data_for_check& data_for_check_; std::thread thread_; + std::thread reply_thread_{}; + bool suppress_message_{}; }; endpoint(const std::string& name, const std::string& digest, boost::barrier& sync) @@ -193,6 +231,9 @@ class endpoint { } while(true) { auto session_id = connection_queue.listen(); + if (session_id == 0) { // means timeout + continue; + } if (connection_queue.is_terminated()) { connection_queue.confirm_terminated(); break; @@ -205,7 +246,10 @@ class endpoint { connection_queue.accept(index, session_id); try { std::unique_lock lk(mutex_); - worker_ = std::make_unique(session_id, std::move(wire), [&connection_queue, index](){ connection_queue.disconnect(index); }, responses_, component_id_, current_request_); + worker_ = std::make_unique(session_id, std::move(wire), [&connection_queue, index](){ connection_queue.disconnect(index); }, data_for_check_); + if (suppress_message_) { + worker_->suppress_message(); + } condition_.notify_all(); } catch (std::exception& ex) { LOG(ERROR) << ex.what(); @@ -214,17 +258,24 @@ class endpoint { } } void push_response(std::string_view response, tateyama::proto::framework::response::Header_PayloadType type) { - responses_.emplace(response, type); + data_for_check_.responses_.emplace(response, type); } const tateyama::framework::component::id_type component_id() const { - return component_id_; + return data_for_check_.component_id_; } const std::string& current_request() const { - return current_request_; + return data_for_check_.current_request_; + } + const std::size_t update_expiration_time_count() { + return data_for_check_.uet_count_; } void terminate() { + worker_->finish(); container_->get_connection_queue().request_terminate(); } + void suppress_message() { + suppress_message_ = true; + } private: std::string name_; @@ -235,10 +286,9 @@ class endpoint { std::unique_ptr worker_{}; std::mutex mutex_{}; std::condition_variable condition_{}; - std::queue responses_{}; bool notified_{false}; - tateyama::framework::component::id_type component_id_{}; - std::string current_request_{}; + data_for_check data_for_check_{}; + bool suppress_message_{}; }; } // namespace tateyama::test_utils diff --git a/test/tateyama/test_utils/server_mock.h b/test/tateyama/test_utils/server_mock.h index 083b056..d43629a 100644 --- a/test/tateyama/test_utils/server_mock.h +++ b/test/tateyama/test_utils/server_mock.h @@ -52,6 +52,12 @@ class server_mock { const std::string& current_request() const { return endpoint_.current_request(); } + const std::size_t update_expiration_time_count() { + return endpoint_.update_expiration_time_count(); + } + void suppress_message() { + endpoint_.suppress_message(); + } private: std::string name_; diff --git a/test/tateyama/test_utils/server_wires_mock.h b/test/tateyama/test_utils/server_wires_mock.h index 3ec8276..a297691 100644 --- a/test/tateyama/test_utils/server_wires_mock.h +++ b/test/tateyama/test_utils/server_wires_mock.h @@ -55,7 +55,19 @@ class server_wire_container_mock bip_buffer_ = bip_buffer; } tateyama::common::wire::message_header peep() { - return wire_->peep(bip_buffer_); + while (true) { + try { + return wire_->peep(bip_buffer_); + } catch (std::runtime_error &ex) { + if (!suppress_message_) { + std::cout << ex.what() << std::endl; + } + if (finish_) { + break; + } + } + } + return {tateyama::common::wire::message_header::terminate_request, 0}; } std::string_view payload() { return wire_->payload(bip_buffer_); @@ -66,6 +78,8 @@ class server_wire_container_mock std::size_t read_point() { return wire_->read_point(); } void dispose() { wire_->dispose(); } void notify() { wire_->notify(); } + void finish() { finish_ = true; } + void suppress_message() { suppress_message_ = true; } // for mainly client, except for terminate request from server void write(const char* from, const std::size_t len, tateyama::common::wire::message_header::index_type index) { @@ -75,6 +89,8 @@ class server_wire_container_mock private: tateyama::common::wire::unidirectional_message_wire* wire_{}; char* bip_buffer_{}; + bool finish_{}; + bool suppress_message_{}; }; class response_wire_container_mock { @@ -100,7 +116,18 @@ class server_wire_container_mock // for client tateyama::common::wire::response_header await() { - return wire_->await(bip_buffer_); + while (true) { + try { + return wire_->await(bip_buffer_); + } catch (std::runtime_error &ex) { + if (!suppress_message_) { + std::cout << ex.what() << std::endl; + } + if (finish_) { + break; + } + } + } } [[nodiscard]] tateyama::common::wire::response_header::length_type get_length() const { return wire_->get_length(); @@ -117,11 +144,15 @@ class server_wire_container_mock void close() { wire_->close(); } + void finish() { finish_ = true; } + void suppress_message() { suppress_message_ = true; } private: tateyama::common::wire::unidirectional_response_wire* wire_{}; char* bip_buffer_{}; std::mutex mtx_{}; + bool finish_{}; + bool suppress_message_{}; }; server_wire_container_mock(std::string_view name, std::string_view mutex_file) : name_(name) { @@ -167,6 +198,14 @@ class server_wire_container_mock request_wire_.notify(); response_wire_.notify_shutdown(); } + void finish() { + request_wire_.finish(); + response_wire_.finish(); + } + void suppress_message() { + request_wire_.suppress_message(); + response_wire_.suppress_message(); + } private: std::string name_; diff --git a/test/tateyama/transport/transport_test.cpp b/test/tateyama/transport/client_wire_test.cpp similarity index 95% rename from test/tateyama/transport/transport_test.cpp rename to test/tateyama/transport/client_wire_test.cpp index 9a366f9..85086e0 100644 --- a/test/tateyama/transport/transport_test.cpp +++ b/test/tateyama/transport/client_wire_test.cpp @@ -36,7 +36,7 @@ namespace tateyama::session { static constexpr std::size_t threads = 16; static constexpr std::size_t loops = 1024; -class transport_test : public ::testing::Test { +class client_wire_test : public ::testing::Test { public: class worker { constexpr static std::size_t HEADER_MESSAGE_VERSION_MAJOR = 0; @@ -121,10 +121,10 @@ class transport_test : public ::testing::Test { }; virtual void SetUp() { - helper_ = std::make_unique("transport_test", 20401); + helper_ = std::make_unique("client_wire_test", 20401); helper_->set_up(); auto bst_conf = tateyama::configuration::bootstrap_configuration::create_bootstrap_configuration(helper_->conf_file_path()); - server_mock_ = std::make_unique("transport_test", bst_conf.digest(), sync_); + server_mock_ = std::make_unique("client_wire_test", bst_conf.digest(), sync_); sync_.wait(); } @@ -139,8 +139,8 @@ class transport_test : public ::testing::Test { }; -TEST_F(transport_test, echo) { - tateyama::common::wire::session_wire_container wire(tateyama::common::wire::connection_container("transport_test").connect()); +TEST_F(client_wire_test, echo) { + tateyama::common::wire::session_wire_container wire(tateyama::common::wire::connection_container("client_wire_test").connect()); std::vector> workers{}; boost::barrier thread_sync{threads}; diff --git a/test/tateyama/transport/timer_test.cpp b/test/tateyama/transport/timer_test.cpp new file mode 100644 index 0000000..4c3582e --- /dev/null +++ b/test/tateyama/transport/timer_test.cpp @@ -0,0 +1,211 @@ +/* + * Copyright 2022-2024 Project Tsurugi. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_root.h" + +#include +#include +#include +#include +#include + +#include + +#include "tateyama/configuration/bootstrap_configuration.h" +#include "tateyama/transport/transport.h" +#include "tateyama/test_utils/server_mock.h" + +#include +#include + +namespace tateyama::session { + +static constexpr std::size_t threads = 16; +static constexpr std::size_t loops = 1024; + +class timer_test : public ::testing::Test { + constexpr static std::size_t EXPIRATION_SECONDS = 6; + constexpr static std::size_t SLEEP_TIME = 8; + + class transport_mock { + constexpr static std::size_t HEADER_MESSAGE_VERSION_MAJOR = 0; + constexpr static std::size_t HEADER_MESSAGE_VERSION_MINOR = 0; + constexpr static std::size_t CORE_MESSAGE_VERSION_MAJOR = 0; + constexpr static std::size_t CORE_MESSAGE_VERSION_MINOR = 0; + constexpr static tateyama::framework::component::id_type TYPE = 2468; + + public: + transport_mock(tateyama::common::wire::session_wire_container& wire) : wire_(wire) { + header_.set_service_message_version_major(HEADER_MESSAGE_VERSION_MAJOR); + header_.set_service_message_version_minor(HEADER_MESSAGE_VERSION_MINOR); + header_.set_service_id(TYPE); + + time_ = std::make_unique(EXPIRATION_SECONDS, [this](){ + auto ret = update_expiration_time(); + if (ret.has_value()) { + return ret.value().result_case() == tateyama::proto::core::response::UpdateExpirationTime::ResultCase::kSuccess; + } + return false; + }); + } + std::optional send(std::string_view request) { + std::stringstream ss{}; + if(auto res = tateyama::utils::SerializeDelimitedToOstream(header_, std::addressof(ss)); ! res) { + return std::nullopt; + } + if(auto res = tateyama::utils::PutDelimitedBodyToOstream(request, std::addressof(ss)); ! res) { + return std::nullopt; + } + auto index = wire_.search_slot(); + wire_.send(ss.str(), index); + + if (auto response_opt = receive(index); response_opt) { + return response_opt.value(); + } + return std::nullopt; + } + private: + tateyama::common::wire::session_wire_container& wire_; + tateyama::proto::framework::request::Header header_{}; + std::unique_ptr time_{}; + + std::optional receive(tateyama::common::wire::message_header::index_type index) { + std::string response_message{}; + + while (true) { + try { + wire_.receive(response_message, index); + break; + } catch (std::runtime_error &e) { + continue; + } + } + + ::tateyama::proto::framework::response::Header header{}; + google::protobuf::io::ArrayInputStream in{response_message.data(), static_cast(response_message.length())}; + if(auto res = tateyama::utils::ParseDelimitedFromZeroCopyStream(std::addressof(header), std::addressof(in), nullptr); ! res) { + return std::nullopt; + } + std::string_view response{}; + bool eof{}; + if(auto res = tateyama::utils::GetDelimitedBodyFromZeroCopyStream(std::addressof(in), &eof, response); ! res) { + return std::nullopt; + } + return std::string{response}; + } + + std::optional update_expiration_time() { + tateyama::proto::core::request::UpdateExpirationTime uet_request{}; + + tateyama::proto::core::request::Request request{}; + *(request.mutable_update_expiration_time()) = uet_request; + + return send(request); + } + template + std::optional send(::tateyama::proto::core::request::Request& request) { + tateyama::proto::framework::request::Header fwrq_header{}; + fwrq_header.set_service_message_version_major(HEADER_MESSAGE_VERSION_MAJOR); + fwrq_header.set_service_message_version_minor(HEADER_MESSAGE_VERSION_MINOR); + fwrq_header.set_service_id(tateyama::framework::service_id_routing); + + std::stringstream sst{}; + if(auto res = tateyama::utils::SerializeDelimitedToOstream(fwrq_header, std::addressof(sst)); ! res) { + return std::nullopt; + } + request.set_service_message_version_major(CORE_MESSAGE_VERSION_MAJOR); + request.set_service_message_version_minor(CORE_MESSAGE_VERSION_MINOR); + if(auto res = tateyama::utils::SerializeDelimitedToOstream(request, std::addressof(sst)); ! res) { + return std::nullopt; + } + auto slot_index = wire_.search_slot(); + wire_.send(sst.str(), slot_index); + + std::string res_message{}; + while (true) { + try { + wire_.receive(res_message, slot_index); + break; + } catch (std::runtime_error &e) { + std::cerr << e.what() << std::endl; + continue; + } + } + tateyama::proto::framework::response::Header fwrs_header{}; + google::protobuf::io::ArrayInputStream ins{res_message.data(), static_cast(res_message.length())}; + if(auto res = tateyama::utils::ParseDelimitedFromZeroCopyStream(std::addressof(fwrs_header), std::addressof(ins), nullptr); ! res) { + return std::nullopt; + } + std::string_view payload{}; + if (auto res = tateyama::utils::GetDelimitedBodyFromZeroCopyStream(std::addressof(ins), nullptr, payload); ! res) { + return std::nullopt; + } + T response{}; + if(auto res = response.ParseFromArray(payload.data(), payload.length()); ! res) { + return std::nullopt; + } + return response; + } + }; + +public: + class worker { + public: + worker(tateyama::common::wire::session_wire_container& wire) : transport_mock_(std::make_unique(wire)) { + } + void operator()() { + std::string time = std::to_string(SLEEP_TIME); + auto rv = transport_mock_->send(time); + if (!rv.has_value()) { + FAIL(); + } + } + + private: + std::unique_ptr transport_mock_{}; + }; + + virtual void SetUp() { + helper_ = std::make_unique("timer_test", 20402); + helper_->set_up(); + auto bst_conf = tateyama::configuration::bootstrap_configuration::create_bootstrap_configuration(helper_->conf_file_path()); + server_mock_ = std::make_unique("timer_test", bst_conf.digest(), sync_); + server_mock_->suppress_message(); + sync_.wait(); + } + + virtual void TearDown() { + helper_->tear_down(); + } + +protected: + std::unique_ptr helper_{}; + std::unique_ptr server_mock_{}; + boost::barrier sync_{2}; +}; + + +TEST_F(timer_test, fundamental) { + tateyama::common::wire::session_wire_container wire(tateyama::common::wire::connection_container("timer_test").connect()); + std::unique_ptr w = std::make_unique(wire); + std::thread thread = std::thread(std::ref(*w)); + + thread.join(); + EXPECT_GT(server_mock_->update_expiration_time_count(), 0); + wire.close(); +} + +} // namespace tateyama::session