Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct cleanup on repeated waiting for never arriving messages #43

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions include/mav/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ namespace mav {
using Expectation = std::shared_ptr<std::promise<Message>>;

private:
using ExpectationWeakRef = std::weak_ptr<std::promise<Message>>;

struct FunctionCallback {
std::function<void(const Message &message)> callback;
std::function<void(const std::exception_ptr& exception)> error_callback;
};

struct PromiseCallback {
Expectation promise;
ExpectationWeakRef promise;
std::function<bool(const Message &message)> selector;
};

Expand All @@ -88,6 +89,11 @@ namespace mav {

public:

size_t callbackCount() {
std::scoped_lock<std::mutex> lock(_message_callback_mtx);
return _message_callbacks.size();
}

void removeAllCallbacks() {
std::scoped_lock<std::mutex> lock(_message_callback_mtx);
_message_callbacks.clear();
Expand All @@ -102,7 +108,7 @@ namespace mav {
return _partner;
}

void consumeMessageFromNetwork(const Message& message) {
void consumeMessageFromNetwork(const Message& message) noexcept {
// in case we received a heartbeat, update last heartbeat time, to keep the connection alive.
_last_received_ms = millis();

Expand All @@ -121,19 +127,24 @@ namespace mav {
}
it++;
} else if constexpr (std::is_same_v<T, PromiseCallback>) {
if (arg.selector(message)) {
arg.promise->set_value(message);
auto promise = arg.promise.lock();
if (!promise) {
it = _message_callbacks.erase(it);
} else {
it++;
if (arg.selector(message)) {
promise->set_value(message);
it = _message_callbacks.erase(it);
} else {
it++;
}
}
}
}, callback);
}
}
}

void consumeNetworkExceptionFromNetwork(const std::exception_ptr& exception) {
void consumeNetworkExceptionFromNetwork(const std::exception_ptr& exception) noexcept {
_underlying_network_fault = true;
std::scoped_lock<std::mutex> lock(_message_callback_mtx);
auto it = _message_callbacks.begin();
Expand All @@ -147,8 +158,13 @@ namespace mav {
}
it++;
} else if constexpr (std::is_same_v<T, PromiseCallback>) {
arg.promise->set_exception(exception);
it = _message_callbacks.erase(it);
auto promise = arg.promise.lock();
if (!promise) {
it = _message_callbacks.erase(it);
} else {
promise->set_exception(exception);
it = _message_callbacks.erase(it);
}
}
}, callback);
}
Expand Down
44 changes: 43 additions & 1 deletion tests/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,24 @@ TEST_CASE("Create network runtime") {
CHECK_THROWS_AS(auto message = connection->receive(expectation, 100), TimeoutException);
}

SUBCASE("Receive throws a NetworkError if the interface fails") {
SUBCASE("Receive throws a NetworkError if the interface fails, error callback gets called") {
interface.reset();

// add a callback using the callback API. The error should then call the error callback
auto error_callback_called_promise = std::promise<void>();
connection->addMessageCallback([](const Message &message) {
// do nothing
}, [&error_callback_called_promise](const std::exception_ptr& exception) {
error_callback_called_promise.set_value();
CHECK_THROWS_AS(std::rethrow_exception(exception), NetworkError);
});

auto expectation = connection->expect("TEST_MESSAGE");
interface.makeFailOnNextReceive();
// Receive on the sync api. The receive should then throw an exception
CHECK_THROWS_AS(auto message = connection->receive(expectation), NetworkError);
CHECK((error_callback_called_promise.get_future().wait_for(std::chrono::seconds(2)) != std::future_status::timeout));
connection->removeAllCallbacks();
}

SUBCASE("Connection recycled on recover after fail") {
Expand Down Expand Up @@ -265,4 +278,33 @@ TEST_CASE("Create network runtime") {
interface_partner));
CHECK(found);
}

SUBCASE("Correct callback called when message is received") {
interface.reset();
std::promise<void> callback_called_promise;
auto callback_called_future = callback_called_promise.get_future();

connection->addMessageCallback([&callback_called_promise](const Message &message) {
if (message.name() == "TEST_MESSAGE") {
callback_called_promise.set_value();
}
});

interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner);
CHECK((callback_called_future.wait_for(std::chrono::seconds(2)) != std::future_status::timeout));
connection->removeAllCallbacks();
}

SUBCASE("Callbacks are cleaned up on receive timeout") {
interface.reset();
for (int i = 0; i < 10; i++) {
auto expectation = connection->expect("TEST_MESSAGE");
CHECK_THROWS_AS(auto message = connection->receive(expectation, 100), TimeoutException);
}
// send a heartbeat. Any message will clear expired expectations
interface.addToReceiveQueue("\xfd\x09\x00\x00\x00\xfd\x01\x00\x00\x00\x04\x00\x00\x00\x01\x02\x03\x05\x06\x77\x53"s, interface_partner);
// wait for the heartbeat to be received, to make sure timing is correct in test
connection->receive("HEARTBEAT");
CHECK_EQ(connection->callbackCount(), 0);
}
}
Loading