diff --git a/src/raft.cpp b/src/raft.cpp index 400a759..7289bac 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -124,6 +124,11 @@ TVolatileState& TVolatileState::SetBackOff(uint32_t id, int size) { return *this; } +TVolatileState& TVolatileState::SetLeaderId(uint32_t id) { + LeaderId = id; + return *this; +} + TVolatileState& TVolatileState::SetCommitIndex(int index) { CommitIndex = index; @@ -235,6 +240,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolderLeaderId) .SetCommitIndex(commitIndex) .SetElectionDue(MakeElection(now)); @@ -283,14 +289,41 @@ void TRaft::OnCommandRequest(TMessageHolder command, const std: } } else if (StateName == EState::FOLLOWER && replyTo) { if (command->Flags & TCommandRequest::EWrite) { - // TODO: send error code - replyTo->Send(NewHoldedMessage(TCommandResponse {.Index = 0})); + if (command->Cookie) { + // already forwarded + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } + + if (VolatileState->LeaderId == 0) { + // TODO: wait for state change + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } + + assert(VolatileState->LeaderId != Id); + // Forward + command->Cookie = std::max(1, ForwardCookie); + Nodes[VolatileState->LeaderId]->Send(std::move(command)); + Forwarded[ForwardCookie] = replyTo; + ForwardCookie++; } else { Waiting.emplace(TWaiting{State->LastLogIndex, std::move(command), replyTo}); } } else if (StateName == EState::CANDIDATE && replyTo) { - // wait + // TODO: wait for state change + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + } +} + +void TRaft::OnCommandResponse(TMessageHolder command) { + // forwarded + auto it = Forwarded.find(command->Cookie); + if (it == Forwarded.end()) { + return; } + it->second->Send(std::move(command)); + Forwarded.erase(it); } TMessageHolder TRaft::CreateVote(uint32_t nodeId) { @@ -370,6 +403,8 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con // client request if (auto maybeCommandRequest = message.Maybe()) { return OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo); + } else if (auto maybeCommandResponse = message.Maybe()) { + return OnCommandResponse(std::move(maybeCommandResponse.Cast())); } if (message.IsEx()) { @@ -420,7 +455,7 @@ void TRaft::ProcessWaiting() { auto lastApplied = VolatileState->LastApplied; while (!Waiting.empty() && Waiting.top().Index <= lastApplied) { auto w = Waiting.top(); Waiting.pop(); - TMessageHolder reply; + TMessageHolder reply; if (w.Command->Flags & TCommandRequest::EWrite) { while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) { WriteAnswers.pop(); @@ -428,10 +463,11 @@ void TRaft::ProcessWaiting() { assert(!WriteAnswers.empty()); auto answer = std::move(WriteAnswers.front()); WriteAnswers.pop(); assert(answer.Index == w.Index); - reply = std::move(answer.Reply); + reply = std::move(answer.Reply.Cast()); } else { - reply = Rsm->Read(std::move(w.Command), w.Index); + reply = Rsm->Read(std::move(w.Command), w.Index).Cast(); } + reply->Cookie = w.Command->Cookie; w.ReplyTo->Send(std::move(reply)); } } diff --git a/src/raft.h b/src/raft.h index f068a52..65b6c39 100644 --- a/src/raft.h +++ b/src/raft.h @@ -41,6 +41,7 @@ using TNodeDict = std::unordered_map>; struct TVolatileState { uint64_t CommitIndex = 0; uint64_t LastApplied = 0; + uint32_t LeaderId = 0; std::unordered_map NextIndex; std::unordered_map MatchIndex; std::unordered_set Votes; @@ -63,6 +64,7 @@ struct TVolatileState { TVolatileState& SetRpcDue(uint32_t id, ITimeSource::Time rpcDue); TVolatileState& SetBatchSize(uint32_t id, int size); TVolatileState& SetBackOff(uint32_t id, int size); + TVolatileState& SetLeaderId(uint32_t id); bool operator==(const TVolatileState& other) const { return CommitIndex == other.CommitIndex && @@ -128,6 +130,7 @@ class TRaft { void OnAppendEntries(TMessageHolder message); void OnCommandRequest(TMessageHolder message, const std::shared_ptr& replyTo); + void OnCommandResponse(TMessageHolder message); void LeaderTimeout(ITimeSource::Time now); void CandidateTimeout(ITimeSource::Time now); @@ -163,6 +166,8 @@ class TRaft { TMessageHolder Reply; }; std::queue WriteAnswers; + uint32_t ForwardCookie = 1; + std::unordered_map> Forwarded; EState StateName; uint32_t Seed = 31337; diff --git a/src/server.cpp b/src/server.cpp index 1e607b1..97f21da 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -129,6 +129,7 @@ NNet::TVoidTask TRaftServer::InboundConnection(TSocket socket) { } catch (const std::exception & ex) { std::cerr << "Exception: " << ex.what() << "\n"; } + // TODO: erase also from Forwarded Nodes.erase(client); co_return; } @@ -137,6 +138,13 @@ template void TRaftServer::Serve() { Idle(); InboundServe(); + std::vector tasks; + for (auto& node : Nodes) { + auto realNode = std::dynamic_pointer_cast>(node); + if (realNode) { + tasks.emplace_back(OutboundServe(realNode)); + } + } } template @@ -146,6 +154,23 @@ void TRaftServer::DrainNodes() { } } +template +NNet::TVoidTask TRaftServer::OutboundServe(std::shared_ptr> node) { + // read forwarded replies + while (true) { + try { + auto mes = co_await TMessageReader(node->Sock()).Read(); + // TODO: check message type + // TODO: should be only TCommandResponse + Raft->Process(TimeSource->Now(), std::move(mes), nullptr); + } catch (const std::exception& ex) { + std::cerr << "Exception: " << ex.what() << "\n"; + } + co_await Poller.Sleep(std::chrono::milliseconds(1000)); + } + co_return; +} + template NNet::TVoidTask TRaftServer::InboundServe() { while (true) { diff --git a/src/server.h b/src/server.h index d4c56a7..570f950 100644 --- a/src/server.h +++ b/src/server.h @@ -87,6 +87,7 @@ class TNode: public INode { void Send(TMessageHolder message) override; void Drain() override; + bool IsConnected() { return Connected; } TSocket& Sock() { return Socket; } @@ -134,6 +135,7 @@ class TRaftServer { private: NNet::TVoidTask InboundServe(); NNet::TVoidTask InboundConnection(TSocket socket); + NNet::TVoidTask OutboundServe(std::shared_ptr>); NNet::TVoidTask Idle(); void DrainNodes(); void DebugPrint();