Skip to content

Commit

Permalink
revise client_wire.h and add timer_test.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
t-horikawa committed Aug 6, 2024
1 parent 6815600 commit a4bf733
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 35 deletions.
29 changes: 21 additions & 8 deletions src/tateyama/transport/client_wire.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <atomic>
#include <array>
#include <mutex>
#include <condition_variable>
#include <stdexcept> // std::runtime_error

#include "wire.h"
Expand Down Expand Up @@ -261,32 +262,42 @@ class session_wire_container
slot& my_slot = slot_status_.at(static_cast<std::size_t>(slot_index));

while (true) {
{
std::unique_lock<std::mutex> lock(mtx_receive_);
cnd_receive_.wait(lock, [this, slot_index]{ return slot_status_.at(static_cast<std::size_t>(slot_index)).valid() || !using_wire_.load(); });
}
if (my_slot.valid()) {
my_slot.consume(res_message);
cnd_receive_.notify_all();
return;
}
{
std::unique_lock<std::mutex> lock(mtx_receive_);

// check again to avoid race
if (my_slot.valid()) {
my_slot.consume(res_message);
return;
}
bool expected = false;
if (!using_wire_.compare_exchange_weak(expected, true)) {
continue;
}

try {
auto header_received = response_wire_.await();
auto index_received = header_received.get_idx();
if (index_received == slot_index) {
res_message.resize(header_received.get_length());
response_wire_.read(res_message.data());
my_slot.receive_and_consume(header_received.get_type());
using_wire_.store(false);
cnd_receive_.notify_all();
return;
}
auto& slot_received = slot_status_.at(static_cast<std::size_t>(index_received));
std::string& message_received = slot_received.pre_receive(header_received.get_type());
message_received.resize(header_received.get_length());
response_wire_.read(message_received.data());
slot_received.post_receive();
using_wire_.store(false);
cnd_receive_.notify_all();
} catch (std::runtime_error& ex) {
using_wire_.store(false);
cnd_receive_.notify_all();
throw ex;
}
}
}
Expand All @@ -304,6 +315,8 @@ class session_wire_container
std::array<slot, slot_size> slot_status_{};
std::mutex mtx_send_{};
std::mutex mtx_receive_{};
std::condition_variable cnd_receive_{};
std::atomic_bool using_wire_{};

void dispose_resultset_wire(std::unique_ptr<resultset_wires_container>& container) {
container->set_closed();
Expand Down
6 changes: 3 additions & 3 deletions src/tateyama/transport/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class transport {
session_id_ = handshake_response.success().session_id();
header_.set_session_id(session_id_);

time_ = std::make_unique<tateyama::common::wire::timer>(EXPIRATION_SECONDS, [this](){
timer_ = std::make_unique<tateyama::common::wire::timer>(EXPIRATION_SECONDS, [this](){
auto ret = update_expiration_time();
if (ret.has_value()) {
return ret.value().result_case() == tateyama::proto::core::response::UpdateExpirationTime::ResultCase::kSuccess;
Expand All @@ -94,7 +94,7 @@ class transport {

~transport() {
try {
time_ = nullptr;
timer_ = nullptr;
if (!closed_) {
close();
}
Expand Down Expand Up @@ -365,7 +365,7 @@ class transport {
std::unique_ptr<tateyama::server::status_info_bridge> status_info_{};
std::size_t session_id_{};
bool closed_{};
std::unique_ptr<tateyama::common::wire::timer> time_{};
std::unique_ptr<tateyama::common::wire::timer> timer_{};

std::string digest() {
auto bst_conf = configuration::bootstrap_configuration::create_bootstrap_configuration(FLAGS_conf);
Expand Down
84 changes: 67 additions & 17 deletions test/tateyama/test_utils/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,24 @@ class endpoint_response {
class endpoint {
constexpr static tateyama::common::wire::response_header::msg_type RESPONSE_BODY = 1;

class data_for_check {
public:
std::queue<endpoint_response> responses_{};
tateyama::framework::component::id_type component_id_{};
std::string current_request_{};
std::size_t uet_count_{};
};

public:
class worker {
public:
worker(std::size_t session_id, std::unique_ptr<tateyama::test_utils::server_wire_container_mock> wire, std::function<void(void)> clean_up, std::queue<endpoint_response>& responses, tateyama::framework::component::id_type& component_id, std::string& current_request)
: session_id_(session_id), wire_(std::move(wire)), clean_up_(std::move(clean_up)), responses_(responses), component_id_(component_id), current_request_(current_request), thread_(std::thread(std::ref(*this))) {
worker(std::size_t session_id, std::unique_ptr<tateyama::test_utils::server_wire_container_mock> wire, std::function<void(void)> clean_up, data_for_check& data_for_check)
: session_id_(session_id), wire_(std::move(wire)), clean_up_(std::move(clean_up)), data_for_check_(data_for_check), thread_(std::thread(std::ref(*this))) {
}
~worker() {
if (reply_thread_.joinable()) {
reply_thread_.join();
}
if (thread_.joinable()) {
thread_.join();
}
Expand All @@ -81,7 +92,7 @@ class endpoint {
throw std::runtime_error("error parsing request message");
}
std::stringstream ss{};
if (component_id_ = req_header.service_id(); component_id_ == tateyama::framework::service_id_endpoint_broker) {
if (data_for_check_.component_id_ = req_header.service_id(); data_for_check_.component_id_ == tateyama::framework::service_id_endpoint_broker) {
::tateyama::proto::framework::response::Header header{};
if(auto res = tateyama::utils::SerializeDelimitedToOstream(header, std::addressof(ss)); ! res) {
throw std::runtime_error("error formatting response message");
Expand All @@ -102,6 +113,7 @@ class endpoint {
if(auto res = tateyama::utils::SerializeDelimitedToOstream(header, std::addressof(ss)); ! res) {
throw std::runtime_error("error formatting response message");
}
data_for_check_.uet_count_++;
tateyama::proto::core::response::UpdateExpirationTime rp{};
(void) rp.mutable_success();
auto body = rp.SerializeAsString();
Expand All @@ -127,18 +139,38 @@ class endpoint {
auto reply_message = ss.str();
wire_->get_response_wire().write(reply_message.data(), tateyama::common::wire::response_header(index, reply_message.length(), RESPONSE_BODY));
continue;
} else if (req_header.service_id() == static_cast<tateyama::framework::component::id_type>(2468)) {
std::string_view payload{};
if (auto res = tateyama::utils::GetDelimitedBodyFromZeroCopyStream(std::addressof(in), nullptr, payload); ! res) {
throw std::runtime_error("error reading payload");
}
auto t = std::stoi(std::string(payload));
reply_thread_ = std::thread([this, t, index]{
std::this_thread::sleep_for(std::chrono::seconds(t));
std::stringstream ss{};
::tateyama::proto::framework::response::Header header{};
if(auto res = tateyama::utils::SerializeDelimitedToOstream(header, std::addressof(ss)); ! res) {
throw std::runtime_error("error formatting response message");
}
if(auto res = tateyama::utils::PutDelimitedBodyToOstream(std::string("OK"), std::addressof(ss)); ! res) {
throw std::runtime_error("error formatting response message");
}
auto reply_message = ss.str();
wire_->get_response_wire().write(reply_message.data(), tateyama::common::wire::response_header(index, reply_message.length(), RESPONSE_BODY));
});
continue;
}
{
std::string_view payload{};
if (auto res = tateyama::utils::GetDelimitedBodyFromZeroCopyStream(std::addressof(in), nullptr, payload); ! res) {
throw std::runtime_error("error reading payload");
}
current_request_ = payload;
if (responses_.empty()) {
data_for_check_.current_request_ = payload;
if (data_for_check_.responses_.empty()) {
throw std::runtime_error("response queue is empty");
}
auto reply = responses_.front();
responses_.pop();
auto reply = data_for_check_.responses_.front();
data_for_check_.responses_.pop();
std::stringstream ss{};
::tateyama::proto::framework::response::Header header{};
header.set_payload_type(reply.get_type());
Expand All @@ -155,15 +187,21 @@ class endpoint {
}
clean_up_();
}
void finish() {
wire_->finish();
}
void suppress_message() {
wire_->suppress_message();
}

private:
std::size_t session_id_;
std::unique_ptr<tateyama::test_utils::server_wire_container_mock> wire_;
std::function<void(void)> clean_up_;
std::queue<endpoint_response>& responses_;
tateyama::framework::component::id_type& component_id_;
std::string& current_request_;
data_for_check& data_for_check_;
std::thread thread_;
std::thread reply_thread_{};
bool suppress_message_{};
};

endpoint(const std::string& name, const std::string& digest, boost::barrier& sync)
Expand Down Expand Up @@ -193,6 +231,9 @@ class endpoint {
}
while(true) {
auto session_id = connection_queue.listen();
if (session_id == 0) { // means timeout
continue;
}
if (connection_queue.is_terminated()) {
connection_queue.confirm_terminated();
break;
Expand All @@ -205,7 +246,10 @@ class endpoint {
connection_queue.accept(index, session_id);
try {
std::unique_lock<std::mutex> lk(mutex_);
worker_ = std::make_unique<worker>(session_id, std::move(wire), [&connection_queue, index](){ connection_queue.disconnect(index); }, responses_, component_id_, current_request_);
worker_ = std::make_unique<worker>(session_id, std::move(wire), [&connection_queue, index](){ connection_queue.disconnect(index); }, data_for_check_);
if (suppress_message_) {
worker_->suppress_message();
}
condition_.notify_all();
} catch (std::exception& ex) {
LOG(ERROR) << ex.what();
Expand All @@ -214,17 +258,24 @@ class endpoint {
}
}
void push_response(std::string_view response, tateyama::proto::framework::response::Header_PayloadType type) {
responses_.emplace(response, type);
data_for_check_.responses_.emplace(response, type);
}
const tateyama::framework::component::id_type component_id() const {
return component_id_;
return data_for_check_.component_id_;
}
const std::string& current_request() const {
return current_request_;
return data_for_check_.current_request_;
}
const std::size_t update_expiration_time_count() {
return data_for_check_.uet_count_;
}
void terminate() {
worker_->finish();
container_->get_connection_queue().request_terminate();
}
void suppress_message() {
suppress_message_ = true;
}

private:
std::string name_;
Expand All @@ -235,10 +286,9 @@ class endpoint {
std::unique_ptr<worker> worker_{};
std::mutex mutex_{};
std::condition_variable condition_{};
std::queue<endpoint_response> responses_{};
bool notified_{false};
tateyama::framework::component::id_type component_id_{};
std::string current_request_{};
data_for_check data_for_check_{};
bool suppress_message_{};
};

} // namespace tateyama::test_utils
6 changes: 6 additions & 0 deletions test/tateyama/test_utils/server_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class server_mock {
const std::string& current_request() const {
return endpoint_.current_request();
}
const std::size_t update_expiration_time_count() {
return endpoint_.update_expiration_time_count();
}
void suppress_message() {
endpoint_.suppress_message();
}

private:
std::string name_;
Expand Down
43 changes: 41 additions & 2 deletions test/tateyama/test_utils/server_wires_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,19 @@ class server_wire_container_mock
bip_buffer_ = bip_buffer;
}
tateyama::common::wire::message_header peep() {
return wire_->peep(bip_buffer_);
while (true) {
try {
return wire_->peep(bip_buffer_);
} catch (std::runtime_error &ex) {
if (!suppress_message_) {
std::cout << ex.what() << std::endl;
}
if (finish_) {
break;
}
}
}
return {tateyama::common::wire::message_header::terminate_request, 0};
}
std::string_view payload() {
return wire_->payload(bip_buffer_);
Expand All @@ -66,6 +78,8 @@ class server_wire_container_mock
std::size_t read_point() { return wire_->read_point(); }
void dispose() { wire_->dispose(); }
void notify() { wire_->notify(); }
void finish() { finish_ = true; }
void suppress_message() { suppress_message_ = true; }

// for mainly client, except for terminate request from server
void write(const char* from, const std::size_t len, tateyama::common::wire::message_header::index_type index) {
Expand All @@ -75,6 +89,8 @@ class server_wire_container_mock
private:
tateyama::common::wire::unidirectional_message_wire* wire_{};
char* bip_buffer_{};
bool finish_{};
bool suppress_message_{};
};

class response_wire_container_mock {
Expand All @@ -100,7 +116,18 @@ class server_wire_container_mock

// for client
tateyama::common::wire::response_header await() {
return wire_->await(bip_buffer_);
while (true) {
try {
return wire_->await(bip_buffer_);
} catch (std::runtime_error &ex) {
if (!suppress_message_) {
std::cout << ex.what() << std::endl;
}
if (finish_) {
break;
}
}
}
}
[[nodiscard]] tateyama::common::wire::response_header::length_type get_length() const {
return wire_->get_length();
Expand All @@ -117,11 +144,15 @@ class server_wire_container_mock
void close() {
wire_->close();
}
void finish() { finish_ = true; }
void suppress_message() { suppress_message_ = true; }

private:
tateyama::common::wire::unidirectional_response_wire* wire_{};
char* bip_buffer_{};
std::mutex mtx_{};
bool finish_{};
bool suppress_message_{};
};

server_wire_container_mock(std::string_view name, std::string_view mutex_file) : name_(name) {
Expand Down Expand Up @@ -167,6 +198,14 @@ class server_wire_container_mock
request_wire_.notify();
response_wire_.notify_shutdown();
}
void finish() {
request_wire_.finish();
response_wire_.finish();
}
void suppress_message() {
request_wire_.suppress_message();
response_wire_.suppress_message();
}

private:
std::string name_;
Expand Down
Loading

0 comments on commit a4bf733

Please sign in to comment.