Skip to content

Commit

Permalink
Forward write requests to leader (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius authored Dec 25, 2024
1 parent 3526f9e commit 6d2f912
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 6 deletions.
48 changes: 42 additions & 6 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ TVolatileState& TVolatileState::SetBackOff(uint32_t id, int size) {
return *this;
}

TVolatileState& TVolatileState::SetLeaderId(uint32_t id) {
LeaderId = id;
return *this;
}

TVolatileState& TVolatileState::SetCommitIndex(int index)
{
CommitIndex = index;
Expand Down Expand Up @@ -235,6 +240,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success});

(*VolatileState)
.SetLeaderId(message->LeaderId)
.SetCommitIndex(commitIndex)
.SetElectionDue(MakeElection(now));

Expand Down Expand Up @@ -283,14 +289,41 @@ void TRaft::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std:
}
} else if (StateName == EState::FOLLOWER && replyTo) {
if (command->Flags & TCommandRequest::EWrite) {
// TODO: send error code
replyTo->Send(NewHoldedMessage(TCommandResponse {.Index = 0}));
if (command->Cookie) {
// already forwarded
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
return;
}

if (VolatileState->LeaderId == 0) {
// TODO: wait for state change
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
return;
}

assert(VolatileState->LeaderId != Id);
// Forward
command->Cookie = std::max<uint32_t>(1, ForwardCookie);
Nodes[VolatileState->LeaderId]->Send(std::move(command));
Forwarded[ForwardCookie] = replyTo;
ForwardCookie++;
} else {
Waiting.emplace(TWaiting{State->LastLogIndex, std::move(command), replyTo});
}
} else if (StateName == EState::CANDIDATE && replyTo) {
// wait
// TODO: wait for state change
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
}
}

void TRaft::OnCommandResponse(TMessageHolder<TCommandResponse> command) {
// forwarded
auto it = Forwarded.find(command->Cookie);
if (it == Forwarded.end()) {
return;
}
it->second->Send(std::move(command));
Forwarded.erase(it);
}

TMessageHolder<TRequestVoteRequest> TRaft::CreateVote(uint32_t nodeId) {
Expand Down Expand Up @@ -370,6 +403,8 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
// client request
if (auto maybeCommandRequest = message.Maybe<TCommandRequest>()) {
return OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo);
} else if (auto maybeCommandResponse = message.Maybe<TCommandResponse>()) {
return OnCommandResponse(std::move(maybeCommandResponse.Cast()));
}

if (message.IsEx()) {
Expand Down Expand Up @@ -420,18 +455,19 @@ void TRaft::ProcessWaiting() {
auto lastApplied = VolatileState->LastApplied;
while (!Waiting.empty() && Waiting.top().Index <= lastApplied) {
auto w = Waiting.top(); Waiting.pop();
TMessageHolder<TMessage> reply;
TMessageHolder<TCommandResponse> reply;
if (w.Command->Flags & TCommandRequest::EWrite) {
while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) {
WriteAnswers.pop();
}
assert(!WriteAnswers.empty());
auto answer = std::move(WriteAnswers.front()); WriteAnswers.pop();
assert(answer.Index == w.Index);
reply = std::move(answer.Reply);
reply = std::move(answer.Reply.Cast<TCommandResponse>());
} else {
reply = Rsm->Read(std::move(w.Command), w.Index);
reply = Rsm->Read(std::move(w.Command), w.Index).Cast<TCommandResponse>();
}
reply->Cookie = w.Command->Cookie;
w.ReplyTo->Send(std::move(reply));
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;
struct TVolatileState {
uint64_t CommitIndex = 0;
uint64_t LastApplied = 0;
uint32_t LeaderId = 0;
std::unordered_map<uint32_t, uint64_t> NextIndex;
std::unordered_map<uint32_t, uint64_t> MatchIndex;
std::unordered_set<uint32_t> Votes;
Expand All @@ -63,6 +64,7 @@ struct TVolatileState {
TVolatileState& SetRpcDue(uint32_t id, ITimeSource::Time rpcDue);
TVolatileState& SetBatchSize(uint32_t id, int size);
TVolatileState& SetBackOff(uint32_t id, int size);
TVolatileState& SetLeaderId(uint32_t id);

bool operator==(const TVolatileState& other) const {
return CommitIndex == other.CommitIndex &&
Expand Down Expand Up @@ -128,6 +130,7 @@ class TRaft {
void OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message);

void OnCommandRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);
void OnCommandResponse(TMessageHolder<TCommandResponse> message);

void LeaderTimeout(ITimeSource::Time now);
void CandidateTimeout(ITimeSource::Time now);
Expand Down Expand Up @@ -163,6 +166,8 @@ class TRaft {
TMessageHolder<TMessage> Reply;
};
std::queue<TAnswer> WriteAnswers;
uint32_t ForwardCookie = 1;
std::unordered_map<uint32_t, std::shared_ptr<INode>> Forwarded;

EState StateName;
uint32_t Seed = 31337;
Expand Down
25 changes: 25 additions & 0 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ NNet::TVoidTask TRaftServer<TSocket>::InboundConnection(TSocket socket) {
} catch (const std::exception & ex) {
std::cerr << "Exception: " << ex.what() << "\n";
}
// TODO: erase also from Forwarded
Nodes.erase(client);
co_return;
}
Expand All @@ -137,6 +138,13 @@ template<typename TSocket>
void TRaftServer<TSocket>::Serve() {
Idle();
InboundServe();
std::vector<NNet::TVoidTask> tasks;
for (auto& node : Nodes) {
auto realNode = std::dynamic_pointer_cast<TNode<TSocket>>(node);
if (realNode) {
tasks.emplace_back(OutboundServe(realNode));
}
}
}

template<typename TSocket>
Expand All @@ -146,6 +154,23 @@ void TRaftServer<TSocket>::DrainNodes() {
}
}

template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::OutboundServe(std::shared_ptr<TNode<TSocket>> node) {
// read forwarded replies
while (true) {
try {
auto mes = co_await TMessageReader(node->Sock()).Read();
// TODO: check message type
// TODO: should be only TCommandResponse
Raft->Process(TimeSource->Now(), std::move(mes), nullptr);
} catch (const std::exception& ex) {
std::cerr << "Exception: " << ex.what() << "\n";
}
co_await Poller.Sleep(std::chrono::milliseconds(1000));
}
co_return;
}

template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::InboundServe() {
while (true) {
Expand Down
2 changes: 2 additions & 0 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class TNode: public INode {

void Send(TMessageHolder<TMessage> message) override;
void Drain() override;
bool IsConnected() { return Connected; }
TSocket& Sock() {
return Socket;
}
Expand Down Expand Up @@ -134,6 +135,7 @@ class TRaftServer {
private:
NNet::TVoidTask InboundServe();
NNet::TVoidTask InboundConnection(TSocket socket);
NNet::TVoidTask OutboundServe(std::shared_ptr<TNode<TSocket>>);
NNet::TVoidTask Idle();
void DrainNodes();
void DebugPrint();
Expand Down

0 comments on commit 6d2f912

Please sign in to comment.