Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward write requests to leader #22

Merged
merged 3 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ TVolatileState& TVolatileState::SetBackOff(uint32_t id, int size) {
return *this;
}

TVolatileState& TVolatileState::SetLeaderId(uint32_t id) {
LeaderId = id;
return *this;
}

TVolatileState& TVolatileState::SetCommitIndex(int index)
{
CommitIndex = index;
Expand Down Expand Up @@ -235,6 +240,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success});

(*VolatileState)
.SetLeaderId(message->LeaderId)
.SetCommitIndex(commitIndex)
.SetElectionDue(MakeElection(now));

Expand Down Expand Up @@ -283,14 +289,41 @@ void TRaft::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std:
}
} else if (StateName == EState::FOLLOWER && replyTo) {
if (command->Flags & TCommandRequest::EWrite) {
// TODO: send error code
replyTo->Send(NewHoldedMessage(TCommandResponse {.Index = 0}));
if (command->Cookie) {
// already forwarded
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
return;
}

if (VolatileState->LeaderId == 0) {
// TODO: wait for state change
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
return;
}

assert(VolatileState->LeaderId != Id);
// Forward
command->Cookie = std::max<uint32_t>(1, ForwardCookie);
Nodes[VolatileState->LeaderId]->Send(std::move(command));
Forwarded[ForwardCookie] = replyTo;
ForwardCookie++;
} else {
Waiting.emplace(TWaiting{State->LastLogIndex, std::move(command), replyTo});
}
} else if (StateName == EState::CANDIDATE && replyTo) {
// wait
// TODO: wait for state change
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
}
}

void TRaft::OnCommandResponse(TMessageHolder<TCommandResponse> command) {
// forwarded
auto it = Forwarded.find(command->Cookie);
if (it == Forwarded.end()) {
return;
}
it->second->Send(std::move(command));
Forwarded.erase(it);
}

TMessageHolder<TRequestVoteRequest> TRaft::CreateVote(uint32_t nodeId) {
Expand Down Expand Up @@ -370,6 +403,8 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
// client request
if (auto maybeCommandRequest = message.Maybe<TCommandRequest>()) {
return OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo);
} else if (auto maybeCommandResponse = message.Maybe<TCommandResponse>()) {
return OnCommandResponse(std::move(maybeCommandResponse.Cast()));
}

if (message.IsEx()) {
Expand Down Expand Up @@ -420,18 +455,19 @@ void TRaft::ProcessWaiting() {
auto lastApplied = VolatileState->LastApplied;
while (!Waiting.empty() && Waiting.top().Index <= lastApplied) {
auto w = Waiting.top(); Waiting.pop();
TMessageHolder<TMessage> reply;
TMessageHolder<TCommandResponse> reply;
if (w.Command->Flags & TCommandRequest::EWrite) {
while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) {
WriteAnswers.pop();
}
assert(!WriteAnswers.empty());
auto answer = std::move(WriteAnswers.front()); WriteAnswers.pop();
assert(answer.Index == w.Index);
reply = std::move(answer.Reply);
reply = std::move(answer.Reply.Cast<TCommandResponse>());
} else {
reply = Rsm->Read(std::move(w.Command), w.Index);
reply = Rsm->Read(std::move(w.Command), w.Index).Cast<TCommandResponse>();
}
reply->Cookie = w.Command->Cookie;
w.ReplyTo->Send(std::move(reply));
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;
struct TVolatileState {
uint64_t CommitIndex = 0;
uint64_t LastApplied = 0;
uint32_t LeaderId = 0;
std::unordered_map<uint32_t, uint64_t> NextIndex;
std::unordered_map<uint32_t, uint64_t> MatchIndex;
std::unordered_set<uint32_t> Votes;
Expand All @@ -63,6 +64,7 @@ struct TVolatileState {
TVolatileState& SetRpcDue(uint32_t id, ITimeSource::Time rpcDue);
TVolatileState& SetBatchSize(uint32_t id, int size);
TVolatileState& SetBackOff(uint32_t id, int size);
TVolatileState& SetLeaderId(uint32_t id);

bool operator==(const TVolatileState& other) const {
return CommitIndex == other.CommitIndex &&
Expand Down Expand Up @@ -128,6 +130,7 @@ class TRaft {
void OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message);

void OnCommandRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);
void OnCommandResponse(TMessageHolder<TCommandResponse> message);

void LeaderTimeout(ITimeSource::Time now);
void CandidateTimeout(ITimeSource::Time now);
Expand Down Expand Up @@ -163,6 +166,8 @@ class TRaft {
TMessageHolder<TMessage> Reply;
};
std::queue<TAnswer> WriteAnswers;
uint32_t ForwardCookie = 1;
std::unordered_map<uint32_t, std::shared_ptr<INode>> Forwarded;

EState StateName;
uint32_t Seed = 31337;
Expand Down
25 changes: 25 additions & 0 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ NNet::TVoidTask TRaftServer<TSocket>::InboundConnection(TSocket socket) {
} catch (const std::exception & ex) {
std::cerr << "Exception: " << ex.what() << "\n";
}
// TODO: erase also from Forwarded
Nodes.erase(client);
co_return;
}
Expand All @@ -137,6 +138,13 @@ template<typename TSocket>
void TRaftServer<TSocket>::Serve() {
Idle();
InboundServe();
std::vector<NNet::TVoidTask> tasks;
for (auto& node : Nodes) {
auto realNode = std::dynamic_pointer_cast<TNode<TSocket>>(node);
if (realNode) {
tasks.emplace_back(OutboundServe(realNode));
}
}
}

template<typename TSocket>
Expand All @@ -146,6 +154,23 @@ void TRaftServer<TSocket>::DrainNodes() {
}
}

template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::OutboundServe(std::shared_ptr<TNode<TSocket>> node) {
// read forwarded replies
while (true) {
try {
auto mes = co_await TMessageReader(node->Sock()).Read();
// TODO: check message type
// TODO: should be only TCommandResponse
Raft->Process(TimeSource->Now(), std::move(mes), nullptr);
} catch (const std::exception& ex) {
std::cerr << "Exception: " << ex.what() << "\n";
}
co_await Poller.Sleep(std::chrono::milliseconds(1000));
}
co_return;
}

template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::InboundServe() {
while (true) {
Expand Down
2 changes: 2 additions & 0 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class TNode: public INode {

void Send(TMessageHolder<TMessage> message) override;
void Drain() override;
bool IsConnected() { return Connected; }
TSocket& Sock() {
return Socket;
}
Expand Down Expand Up @@ -134,6 +135,7 @@ class TRaftServer {
private:
NNet::TVoidTask InboundServe();
NNet::TVoidTask InboundConnection(TSocket socket);
NNet::TVoidTask OutboundServe(std::shared_ptr<TNode<TSocket>>);
NNet::TVoidTask Idle();
void DrainNodes();
void DebugPrint();
Expand Down
Loading