diff --git a/src/messages.h b/src/messages.h index 2b842dc..ac68c3a 100644 --- a/src/messages.h +++ b/src/messages.h @@ -5,6 +5,7 @@ #include #include +#include #include enum class EMessageType : uint32_t { @@ -211,6 +212,24 @@ TMessageHolder NewHoldedMessage(uint32_t size) { return NewHoldedMessage(static_cast(T::MessageType), size); } +template +TMessageHolder NewHoldedMessage(T t) { + auto m = NewHoldedMessage(static_cast(T::MessageType), sizeof(T)); + t.Type = T::MessageType; + t.Len = sizeof(T); + memcpy(m->Mes, &t, sizeof(T)); + return m; +} + +template +requires std::derived_from +TMessageHolder NewHoldedMessage(TMessageEx h, T t) { + auto m = NewHoldedMessage(); + memcpy((char*)m.Mes + sizeof(TMessage), (char*)&h + sizeof(TMessage), sizeof(h) - sizeof(TMessage)); + memcpy((char*)m.Mes + sizeof(TMessageEx), (char*)&t + sizeof(TMessageEx), sizeof(T) - sizeof(TMessageEx)); + return m; +} + inline TMessageHolder NewTimeout() { return NewHoldedMessage(static_cast(EMessageType::TIMEOUT), sizeof(TTimeout)); } diff --git a/src/raft.cpp b/src/raft.cpp index b9d0181..211dd79 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -93,11 +93,9 @@ std::unique_ptr TRaft::OnRequestVote(TMessageHolder(); - reply->Src = Id; - reply->Dst = message->Src; - reply->Term = State->CurrentTerm; - reply->VoteGranted = accept; + auto reply = NewHoldedMessage( + TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm}, + TRequestVoteResponse {.VoteGranted = accept}); return std::make_unique(TResult { .NextState = accept ? std::make_unique(TState{ @@ -196,12 +194,9 @@ std::unique_ptr TRaft::OnAppendEntries(TMessageHolderLeaderCommit); } - auto reply = NewHoldedMessage(); - reply->Src = Id; - reply->Dst = message->Src; - reply->Term = State->CurrentTerm; - reply->Success = success; - reply->MatchIndex = matchIndex; + auto reply = NewHoldedMessage( + TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm}, + TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success}); auto nextVolatileState = *VolatileState; nextVolatileState.SetCommitIndex(commitIndex); @@ -241,15 +236,13 @@ std::unique_ptr TRaft::OnAppendEntries(TMessageHolder TRaft::CreateVote() { - auto mes = NewHoldedMessage(); - mes->Src = Id; - mes->Dst = 0; - mes->Term = State->CurrentTerm+1; - mes->CandidateId = Id; - mes->LastLogIndex = State->Log.size(); - mes->LastLogTerm = State->Log.empty() - ? 0 - : State->Log.back()->Term; + auto mes = NewHoldedMessage( + TMessageEx {.Src = Id, .Dst = 0, .Term = State->CurrentTerm+1}, + TRequestVoteRequest { + .LastLogIndex = State->Log.size(), + .LastLogTerm = State->Log.empty() ? 0 : State->Log.back()->Term, + .CandidateId = Id, + }); return mes; } @@ -263,16 +256,15 @@ std::vector> TRaft::CreateAppendEntries() lastIndex = prevIndex; } - auto mes = NewHoldedMessage(); - - mes->Src = Id; - mes->Dst = nodeId; - mes->Term = State->CurrentTerm; - mes->LeaderId = Id; - mes->PrevLogIndex = prevIndex; - mes->PrevLogTerm = State->LogTerm(prevIndex); - mes->LeaderCommit = std::min(VolatileState->CommitIndex, lastIndex); - mes->Nentries = lastIndex - prevIndex; + auto mes = NewHoldedMessage( + TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm}, + TAppendEntriesRequest { + .PrevLogIndex = prevIndex, + .PrevLogTerm = State->LogTerm(prevIndex), + .LeaderCommit = std::min(VolatileState->CommitIndex, lastIndex), + .LeaderId = Id, + .Nentries = static_cast(lastIndex - prevIndex), + }); std::vector> payload; payload.reserve(lastIndex - prevIndex); for (auto i = prevIndex; i < lastIndex; i++) { diff --git a/src/raft.h b/src/raft.h index d910ee7..9afd88f 100644 --- a/src/raft.h +++ b/src/raft.h @@ -16,14 +16,14 @@ struct INode { virtual void Drain() = 0; }; -using TNodeDict = std::unordered_map>; +using TNodeDict = std::unordered_map>; struct TState { uint64_t CurrentTerm = 1; uint32_t VotedFor = 0; std::vector> Log; - int LogTerm(int index = -1) const { + uint64_t LogTerm(int64_t index = -1) const { if (index < 0) { index = Log.size(); }