Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 29, 2023
1 parent 9cb5b96 commit 12c3925
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 32 deletions.
19 changes: 19 additions & 0 deletions src/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <assert.h>
#include <stdint.h>
#include <string.h>
#include <typeinfo>

enum class EMessageType : uint32_t {
Expand Down Expand Up @@ -211,6 +212,24 @@ TMessageHolder<T> NewHoldedMessage(uint32_t size) {
return NewHoldedMessage<T>(static_cast<uint32_t>(T::MessageType), size);
}

template<typename T>
TMessageHolder<T> NewHoldedMessage(T t) {
auto m = NewHoldedMessage<T>(static_cast<uint32_t>(T::MessageType), sizeof(T));
t.Type = T::MessageType;
t.Len = sizeof(T);
memcpy(m->Mes, &t, sizeof(T));
return m;
}

template<typename T>
requires std::derived_from<T, TMessageEx>
TMessageHolder<T> NewHoldedMessage(TMessageEx h, T t) {
auto m = NewHoldedMessage<T>();
memcpy((char*)m.Mes + sizeof(TMessage), (char*)&h + sizeof(TMessage), sizeof(h) - sizeof(TMessage));
memcpy((char*)m.Mes + sizeof(TMessageEx), (char*)&t + sizeof(TMessageEx), sizeof(T) - sizeof(TMessageEx));
return m;
}

inline TMessageHolder<TTimeout> NewTimeout() {
return NewHoldedMessage<TTimeout>(static_cast<uint32_t>(EMessageType::TIMEOUT), sizeof(TTimeout));
}
52 changes: 22 additions & 30 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,9 @@ std::unique_ptr<TResult> TRaft::OnRequestVote(TMessageHolder<TRequestVoteRequest
}
}

auto reply = NewHoldedMessage<TRequestVoteResponse>();
reply->Src = Id;
reply->Dst = message->Src;
reply->Term = State->CurrentTerm;
reply->VoteGranted = accept;
auto reply = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TRequestVoteResponse {.VoteGranted = accept});

return std::make_unique<TResult>(TResult {
.NextState = accept ? std::make_unique<TState>(TState{
Expand Down Expand Up @@ -196,12 +194,9 @@ std::unique_ptr<TResult> TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesReq
commitIndex = std::max(commitIndex, message->LeaderCommit);
}

auto reply = NewHoldedMessage<TAppendEntriesResponse>();
reply->Src = Id;
reply->Dst = message->Src;
reply->Term = State->CurrentTerm;
reply->Success = success;
reply->MatchIndex = matchIndex;
auto reply = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success});

auto nextVolatileState = *VolatileState;
nextVolatileState.SetCommitIndex(commitIndex);
Expand Down Expand Up @@ -241,15 +236,13 @@ std::unique_ptr<TResult> TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesRes
}

TMessageHolder<TRequestVoteRequest> TRaft::CreateVote() {
auto mes = NewHoldedMessage<TRequestVoteRequest>();
mes->Src = Id;
mes->Dst = 0;
mes->Term = State->CurrentTerm+1;
mes->CandidateId = Id;
mes->LastLogIndex = State->Log.size();
mes->LastLogTerm = State->Log.empty()
? 0
: State->Log.back()->Term;
auto mes = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = 0, .Term = State->CurrentTerm+1},
TRequestVoteRequest {
.LastLogIndex = State->Log.size(),
.LastLogTerm = State->Log.empty() ? 0 : State->Log.back()->Term,
.CandidateId = Id,
});
return mes;
}

Expand All @@ -263,16 +256,15 @@ std::vector<TMessageHolder<TAppendEntriesRequest>> TRaft::CreateAppendEntries()
lastIndex = prevIndex;
}

auto mes = NewHoldedMessage<TAppendEntriesRequest>();

mes->Src = Id;
mes->Dst = nodeId;
mes->Term = State->CurrentTerm;
mes->LeaderId = Id;
mes->PrevLogIndex = prevIndex;
mes->PrevLogTerm = State->LogTerm(prevIndex);
mes->LeaderCommit = std::min(VolatileState->CommitIndex, lastIndex);
mes->Nentries = lastIndex - prevIndex;
auto mes = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm},
TAppendEntriesRequest {
.PrevLogIndex = prevIndex,
.PrevLogTerm = State->LogTerm(prevIndex),
.LeaderCommit = std::min(VolatileState->CommitIndex, lastIndex),
.LeaderId = Id,
.Nentries = static_cast<uint32_t>(lastIndex - prevIndex),
});
std::vector<TMessageHolder<TMessage>> payload;
payload.reserve(lastIndex - prevIndex);
for (auto i = prevIndex; i < lastIndex; i++) {
Expand Down
4 changes: 2 additions & 2 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ struct INode {
virtual void Drain() = 0;
};

using TNodeDict = std::unordered_map<int, std::shared_ptr<INode>>;
using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;

struct TState {
uint64_t CurrentTerm = 1;
uint32_t VotedFor = 0;
std::vector<TMessageHolder<TLogEntry>> Log;

int LogTerm(int index = -1) const {
uint64_t LogTerm(int64_t index = -1) const {
if (index < 0) {
index = Log.size();
}
Expand Down

0 comments on commit 12c3925

Please sign in to comment.