Skip to content

Commit

Permalink
Persistent state (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius authored Dec 24, 2024
1 parent f09bf6e commit 455e681
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 72 deletions.
2 changes: 1 addition & 1 deletion examples/kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ int main(int argc, char** argv) {
}

std::shared_ptr<IRsm> rsm = std::make_shared<TKv>();
auto raft = std::make_shared<TRaft>(rsm, myHost.Id, nodes);
auto raft = std::make_shared<TRaft>(rsm, std::make_shared<TState>(), myHost.Id, nodes);
TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller());
socket.Bind();
socket.Listen();
Expand Down
2 changes: 1 addition & 1 deletion server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ int main(int argc, char** argv) {
}

std::shared_ptr<IRsm> rsm = std::make_shared<TDummyRsm>();
auto raft = std::make_shared<TRaft>(rsm, myHost.Id, nodes);
auto raft = std::make_shared<TRaft>(rsm, std::make_shared<TState>(), myHost.Id, nodes);
TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller());
socket.Bind();
socket.Listen();
Expand Down
41 changes: 21 additions & 20 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ TVolatileState& TVolatileState::Vote(uint32_t nodeId)
return *this;
}

TVolatileState& TVolatileState::CommitAdvance(int nservers, const TState& state)
TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state)
{
auto lastIndex = state.Log.size();
auto lastIndex = state.LastLogIndex;
Indices.clear(); Indices.reserve(nservers);
for (auto [_, index] : MatchIndex) {
Indices.push_back(index);
Expand Down Expand Up @@ -128,14 +128,14 @@ TVolatileState& TVolatileState::SetCommitIndex(int index)
return *this;
}

TRaft::TRaft(std::shared_ptr<IRsm> rsm, int node, const TNodeDict& nodes)
TRaft::TRaft(std::shared_ptr<IRsm> rsm, std::shared_ptr<IState> state, int node, const TNodeDict& nodes)
: Rsm(rsm)
, Id(node)
, Nodes(nodes)
, MinVotes((nodes.size()+2+nodes.size()%2)/2)
, Npeers(nodes.size())
, Nservers(nodes.size()+1)
, State(std::make_unique<TState>())
, State(std::move(state))
, VolatileState(std::make_unique<TVolatileState>())
, StateName(EState::FOLLOWER)
{
Expand All @@ -155,7 +155,7 @@ void TRaft::OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequ
if (State->VotedFor == 0 || State->VotedFor == message->CandidateId) {
if (message->LastLogTerm > State->LogTerm()) {
accept = true;
} else if (message->LastLogTerm == State->LogTerm() && message->LastLogIndex >= State->Log.size()) {
} else if (message->LastLogTerm == State->LogTerm() && message->LastLogIndex >= State->LastLogIndex) {
accept = true;
}
}
Expand All @@ -167,6 +167,7 @@ void TRaft::OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequ
if (accept) {
VolatileState->ElectionDue = MakeElection(now);
State->VotedFor = message->CandidateId;
State->Commit();
}

Nodes[reply->Dst]->Send(std::move(reply));
Expand Down Expand Up @@ -205,22 +206,21 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
uint64_t commitIndex = VolatileState->CommitIndex;
bool success = false;
if (message->PrevLogIndex == 0 ||
(message->PrevLogIndex <= State->Log.size()
(message->PrevLogIndex <= State->LastLogIndex
&& State->LogTerm(message->PrevLogIndex) == message->PrevLogTerm))
{
success = true;
auto index = message->PrevLogIndex;
auto& log = State->Log;
for (uint32_t i = 0 ; i < message.PayloadSize; i++) {
auto& data = message.Payload[i];
auto entry = data.Cast<TLogEntry>();
index++;
// replace or append log entries
if (State->LogTerm(index) != entry->Term) {
while (log.size() > index-1) {
log.pop_back();
while (State->LastLogIndex > index-1) {
State->RemoveLast();
}
log.push_back(entry);
State->Append(entry);
}
}

Expand Down Expand Up @@ -269,12 +269,11 @@ void TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message) {
}

void TRaft::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
auto& log = State->Log;
if (command->Flags & TCommandRequest::EWrite) {
auto entry = Rsm->Prepare(command, State->CurrentTerm);
log.emplace_back(std::move(entry));
State->Append(std::move(entry));
}
auto index = log.size();
auto index = State->LastLogIndex;
if (replyTo) {
waiting.emplace(TWaiting{index, std::move(command), replyTo});
}
Expand All @@ -284,8 +283,8 @@ TMessageHolder<TRequestVoteRequest> TRaft::CreateVote(uint32_t nodeId) {
auto mes = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm},
TRequestVoteRequest {
.LastLogIndex = State->Log.size(),
.LastLogTerm = State->Log.empty() ? 0 : State->Log.back()->Term,
.LastLogIndex = State->LastLogIndex,
.LastLogTerm = State->LastLogTerm,
.CandidateId = Id,
});
return mes;
Expand All @@ -294,7 +293,7 @@ TMessageHolder<TRequestVoteRequest> TRaft::CreateVote(uint32_t nodeId) {
TMessageHolder<TAppendEntriesRequest> TRaft::CreateAppendEntries(uint32_t nodeId) {
int batchSize = std::max(1, VolatileState->BatchSize[nodeId]);
auto prevIndex = VolatileState->NextIndex[nodeId] - 1;
auto lastIndex = std::min(prevIndex+batchSize, (uint64_t)State->Log.size());
auto lastIndex = std::min<uint64_t>(prevIndex+batchSize, State->LastLogIndex);
if (VolatileState->MatchIndex[nodeId]+1 < VolatileState->NextIndex[nodeId]) {
lastIndex = prevIndex;
}
Expand All @@ -313,7 +312,7 @@ TMessageHolder<TAppendEntriesRequest> TRaft::CreateAppendEntries(uint32_t nodeId
mes.InitPayload(lastIndex - prevIndex);
uint32_t j = 0;
for (auto i = prevIndex; i < lastIndex; i++) {
mes.Payload[j++] = State->Log[i];
mes.Payload[j++] = State->Get(i);
}
}
return mes;
Expand Down Expand Up @@ -361,6 +360,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
if (messageEx->Term > State->CurrentTerm) {
State->CurrentTerm = messageEx->Term;
State->VotedFor = 0;
State->Commit();
StateName = EState::FOLLOWER;
if (VolatileState->ElectionDue <= now || VolatileState->ElectionDue == ITimeSource::Max) {
VolatileState->ElectionDue = MakeElection(now);
Expand All @@ -385,7 +385,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
void TRaft::ProcessCommitted() {
auto commitIndex = VolatileState->CommitIndex;
for (auto i = VolatileState->LastApplied+1; i <= commitIndex; i++) {
Rsm->Write(State->Log[i-1], i);
Rsm->Write(State->Get(i-1), i);
}
VolatileState->LastApplied = commitIndex;
}
Expand Down Expand Up @@ -424,7 +424,7 @@ void TRaft::CandidateTimeout(ITimeSource::Time now) {
void TRaft::LeaderTimeout(ITimeSource::Time now) {
for (auto& [id, node] : Nodes) {
if (VolatileState->HeartbeatDue[id] <= now
|| (VolatileState->NextIndex[id] <= State->Log.size() &&
|| (VolatileState->NextIndex[id] <= State->LastLogIndex &&
VolatileState->RpcDue[id] <= now))
{
VolatileState->HeartbeatDue[id] = now + TTimeout::Election / 2;
Expand Down Expand Up @@ -453,14 +453,15 @@ void TRaft::ProcessTimeout(ITimeSource::Time now) {
VolatileState = std::move(nextVolatileState);
State->VotedFor = Id;
State->CurrentTerm ++;
State->Commit();
Become(EState::CANDIDATE);
}
}

if (StateName == EState::CANDIDATE) {
int nvotes = VolatileState->Votes.size()+1;
if (nvotes >= MinVotes) {
auto value = State->Log.size()+1;
auto value = State->LastLogIndex+1;
decltype(VolatileState->NextIndex) nextIndex;
decltype(VolatileState->RpcDue) rpcDue;
for (auto [id, _] : Nodes) {
Expand Down
64 changes: 51 additions & 13 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,54 @@ struct TDummyRsm: public IRsm {

using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;

struct TState {
// Abstract interface to the Raft state that TRaft receives through a
// std::shared_ptr<IState>: the current term, the vote, and the replicated log.
// The term/vote/last-entry summary is exposed as plain fields; the log itself
// is reached only through the virtual accessors below, so an implementation
// is free to choose how entries are stored.
struct IState {
// Latest term this node has seen. Starts at 1; TRaft raises it when a message
// with a higher term arrives or when the node starts an election.
uint64_t CurrentTerm = 1;
// Id of the candidate voted for in CurrentTerm; 0 means "no vote cast yet"
// (TRaft resets it to 0 whenever it adopts a higher term).
uint32_t VotedFor = 0;
// Number of entries in the log, i.e. the 1-based index of the last entry;
// 0 while the log is empty. Implementations must keep this in sync inside
// Append()/RemoveLast(): TRaft loops on it while truncating the log.
uint64_t LastLogIndex = 0;
// Term of the last log entry; 0 while the log is empty.
uint64_t LastLogTerm = 0;

// Term of the entry at the given log index; a negative index (the default)
// selects the last entry. TRaft passes 1-based indices such as PrevLogIndex.
// NOTE(review): the default argument of a virtual function is taken from the
// static type of the call, so overrides should repeat the same default (-1).
virtual uint64_t LogTerm(int64_t index = -1) const = 0;
// Drops the last log entry and refreshes LastLogIndex/LastLogTerm.
// NOTE(review): behaviour on an empty log is not defined here; TRaft only
// calls it while LastLogIndex is greater than the index being replaced.
virtual void RemoveLast() = 0;
// Appends an entry to the end of the log and refreshes
// LastLogIndex/LastLogTerm.
virtual void Append(TMessageHolder<TLogEntry>) = 0;
// Returns the entry at a 0-based position: TRaft calls Get(i-1) for the
// entry whose log index is i.
virtual TMessageHolder<TLogEntry> Get(int64_t index) = 0;
// Called by TRaft right after it changes CurrentTerm and/or VotedFor.
// Presumably the hook where a durable implementation flushes those fields to
// storage (TODO confirm); the in-memory TState implements it as a no-op.
virtual void Commit() = 0;
// Virtual destructor: instances are owned and destroyed through IState.
virtual ~IState() = default;
};

struct TState: IState {
std::vector<TMessageHolder<TLogEntry>> Log;

uint64_t LogTerm(int64_t index = -1) const {
TState() = default;
TState(uint64_t currentTerm, uint32_t votedFor, const std::vector<TMessageHolder<TLogEntry>>& log)
: Log(log)
{
CurrentTerm = currentTerm;
VotedFor = votedFor;
if (!log.empty()) {
LastLogIndex = log.size();
LastLogTerm = log.back()->Term;
}
}

void RemoveLast() override {
Log.pop_back();
LastLogIndex = Log.size();
LastLogTerm = Log.empty() ? 0 : Log.back()->Term;
}

void Append(TMessageHolder<TLogEntry> entry) override {
Log.emplace_back(std::move(entry));
LastLogIndex = Log.size();
LastLogTerm = Log.back()->Term;
}

TMessageHolder<TLogEntry> Get(int64_t index) override {
return Log[index];
}

void Commit() override { }

uint64_t LogTerm(int64_t index = -1) const override {
if (index < 0) {
index = Log.size();
}
Expand Down Expand Up @@ -70,7 +112,7 @@ struct TVolatileState {

TVolatileState& Vote(uint32_t id);
TVolatileState& SetLastApplied(int index);
TVolatileState& CommitAdvance(int nservers, const TState& state);
TVolatileState& CommitAdvance(int nservers, const IState& state);
TVolatileState& SetCommitIndex(int index);
TVolatileState& SetElectionDue(ITimeSource::Time);
TVolatileState& SetNextIndex(uint32_t id, uint64_t nextIndex);
Expand All @@ -90,26 +132,22 @@ enum class EState: int {

class TRaft {
public:
TRaft(std::shared_ptr<IRsm> rsm, int node, const TNodeDict& nodes);
TRaft(std::shared_ptr<IRsm> rsm, std::shared_ptr<IState> state, int node, const TNodeDict& nodes);

void Process(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo = {});
void ProcessTimeout(ITimeSource::Time now);

// ut
const auto GetState() const {
return State;
}

EState CurrentStateName() const {
return StateName;
}

void Become(EState newStateName);

const TState* GetState() const {
return State.get();
}

void SetState(const TState& state) {
*State = state;
}

const TVolatileState* GetVolatileState() const {
return VolatileState.get();
}
Expand Down Expand Up @@ -157,7 +195,7 @@ class TRaft {
int MinVotes;
int Npeers;
int Nservers;
std::unique_ptr<TState> State;
std::shared_ptr<IState> State;
std::unique_ptr<TVolatileState> VolatileState;

struct TWaiting {
Expand Down
10 changes: 5 additions & 5 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,17 @@ NNet::TVoidTask TRaftServer<TSocket>::InboundServe() {

template<typename TSocket>
void TRaftServer<TSocket>::DebugPrint() {
auto* state = Raft->GetState();
auto state = Raft->GetState();
auto* volatileState = Raft->GetVolatileState();
if (Raft->CurrentStateName() == EState::LEADER) {
std::cout << "Leader, "
<< "Term: " << state->CurrentTerm << ", "
<< "Index: " << state->Log.size() << ", "
<< "Index: " << state->LastLogIndex << ", "
<< "CommitIndex: " << volatileState->CommitIndex << ", "
<< "LastApplied: " << volatileState->LastApplied << ", ";
std::cout << "Delay: ";
for (auto [id, index] : volatileState->MatchIndex) {
std::cout << id << ":" << (state->Log.size() - index) << " ";
std::cout << id << ":" << (state->LastLogIndex - index) << " ";
}
std::cout << "MatchIndex: ";
for (auto [id, index] : volatileState->MatchIndex) {
Expand All @@ -182,14 +182,14 @@ void TRaftServer<TSocket>::DebugPrint() {
} else if (Raft->CurrentStateName() == EState::CANDIDATE) {
std::cout << "Candidate, "
<< "Term: " << state->CurrentTerm << ", "
<< "Index: " << state->Log.size() << ", "
<< "Index: " << state->LastLogIndex << ", "
<< "CommitIndex: " << volatileState->CommitIndex << ", "
<< "LastApplied: " << volatileState->LastApplied << ", "
<< "\n";
} else if (Raft->CurrentStateName() == EState::FOLLOWER) {
std::cout << "Follower, "
<< "Term: " << state->CurrentTerm << ", "
<< "Index: " << state->Log.size() << ", "
<< "Index: " << state->LastLogIndex << ", "
<< "CommitIndex: " << volatileState->CommitIndex << ", "
<< "LastApplied: " << volatileState->LastApplied << ", "
<< "\n";
Expand Down
Loading

0 comments on commit 455e681

Please sign in to comment.