diff --git a/src/applications/tests/unit_tests/cooked_send_test.cpp b/src/applications/tests/unit_tests/cooked_send_test.cpp index 49a35d37..31154496 100644 --- a/src/applications/tests/unit_tests/cooked_send_test.cpp +++ b/src/applications/tests/unit_tests/cooked_send_test.cpp @@ -12,16 +12,46 @@ using std::map; using std::pair; using std::vector; +/** + * This object contains state that is shared between the replicated test objects + * and the main thread, rather than stored inside the replicated objects. It's + * used to provide a way for the replicated objects to "call back" to the main + * thread. Each replicated object will get a pointer to this object when it is + * constructed or deserialized, set up by the deserialization manager. + */ +struct TestState : public derecho::DeserializationContext { + uint32_t num_messages; + uint32_t num_nodes; + uint32_t counter; + std::atomic done; + void check_test_done() { + if(counter == num_messages * num_nodes) { + done = true; + } + } + TestState(uint32_t num_messages, uint32_t num_nodes) + : num_messages(num_messages), + num_nodes(num_nodes), + counter(0), + done(false) {} +}; + class CookedMessages : public mutils::ByteRepresentable { vector> msgs; // vector of (nodeid, msg #) + TestState* test_state; public: - CookedMessages() = default; - CookedMessages(const vector>& msgs) : msgs(msgs) { - } + // Factory constructor + CookedMessages(TestState* test_state) : test_state(test_state) {} + // Deserialization constructor + CookedMessages(const vector>& msgs, TestState* test_state) + : msgs(msgs), test_state(test_state) {} void send(uint nodeid, uint msg) { msgs.push_back(std::make_pair(nodeid, msg)); + // Count the number of RPC messages received here + test_state->counter++; + test_state->check_test_done(); } vector> get_msgs(uint start_index, uint end_index) { @@ -32,13 +62,23 @@ class CookedMessages : public mutils::ByteRepresentable { return vector>(msgs.begin() + start_index, msgs.begin() + end_index); } - // default state - DEFAULT_SERIALIZATION_SUPPORT(CookedMessages, msgs); + DEFAULT_SERIALIZE(msgs); + static std::unique_ptr from_bytes(mutils::DeserializationManager* dsm, uint8_t const* buffer); + DEFAULT_DESERIALIZE_NOALLOC(CookedMessages); // what operations you want as part of the subgroup REGISTER_RPC_FUNCTIONS(CookedMessages, ORDERED_TARGETS(send, get_msgs)); }; +// Custom deserializer that retrieves the TestState pointer from the DeserializationManager +std::unique_ptr CookedMessages::from_bytes(mutils::DeserializationManager* dsm, uint8_t const* buffer) { + auto msgs_ptr = mutils::from_bytes>>(dsm, buffer); + assert(dsm); + assert(dsm->registered()); + TestState* test_state_ptr = &(dsm->mgr()); + return std::make_unique(*msgs_ptr, test_state_ptr); +} + bool verify_local_order(vector> msgs) { map order; for(auto [nodeid, msg] : msgs) { @@ -58,33 +98,19 @@ int main(int argc, char* argv[]) { } const uint32_t num_nodes = atoi(argv[1]); Conf::initialize(argc, argv); - auto subgroup_membership_function = [num_nodes]( - const std::vector& subgroup_type_order, - const std::unique_ptr& prev_view, derecho::View& curr_view) { - auto& members = curr_view.members; - auto num_members = members.size(); - if(num_members < num_nodes) { - throw subgroup_provisioning_exception(); - } - subgroup_shard_layout_t layout(num_members); - layout[0].push_back(curr_view.make_subview(vector(members))); - derecho::subgroup_allocation_map_t subgroup_allocation; - subgroup_allocation.emplace(std::type_index(typeid(CookedMessages)), std::move(layout)); - return subgroup_allocation; - }; - auto cooked_subgroup_factory = [](persistent::PersistentRegistry*, derecho::subgroup_id_t) { return std::make_unique(); }; + // Configure the default subgroup allocator to put all the nodes in one fixed-size subgroup + SubgroupInfo subgroup_info(derecho::DefaultSubgroupAllocator( + {{std::type_index(typeid(CookedMessages)), + derecho::one_subgroup_policy(derecho::fixed_even_shards(1, num_nodes))}})); - SubgroupInfo subgroup_info(subgroup_membership_function); - std::atomic done = false; uint32_t num_msgs = 500; - auto stability_callback = [num_msgs, num_nodes, &done, counter = (uint)0](subgroup_id_t, node_id_t sender_id, message_id_t index, std::optional>, persistent::version_t) mutable { - counter++; - if(counter == num_msgs * num_nodes) { - done = true; - } + TestState test_state(num_msgs, num_nodes); + auto cooked_subgroup_factory = [&](persistent::PersistentRegistry*, derecho::subgroup_id_t) { + return std::make_unique(&test_state); }; - Group group({stability_callback}, subgroup_info, {}, {}, cooked_subgroup_factory); + // Put a pointer to test_state in Group's vector of DeserializationContexts so it will be passed to CookedMessages + Group group({}, subgroup_info, {&test_state}, {}, cooked_subgroup_factory); cout << "Finished constructing/joining the group" << endl; auto group_members = group.get_members(); @@ -107,7 +133,7 @@ int main(int argc, char* argv[]) { for(uint i = 1; i < num_msgs + 1; ++i) { cookedMessagesHandle.ordered_send(my_rank, i); } - while(!done) { + while(!test_state.done) { } if(my_rank == 0) { uint32_t max_msg_size = getConfUInt64(Conf::SUBGROUP_DEFAULT_MAX_PAYLOAD_SIZE);