Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add persistent state implementation #19

Merged
merged 6 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(miniraft
src/messages.cpp
src/raft.cpp
src/server.cpp
src/persist.cpp
)

target_link_libraries(miniraft PUBLIC coroio)
Expand Down
20 changes: 15 additions & 5 deletions examples/kv.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <string_view>

#include <server.h>
#include <persist.h>

#include "kv.h"

Expand Down Expand Up @@ -73,7 +74,7 @@ TMessageHolder<TLogEntry> TKv::Prepare(TMessageHolder<TCommandRequest> command,
}

template<typename TPoller, typename TSocket>
NNet::TVoidTask Client(TPoller& poller, TSocket socket) {
NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
using TFileHandle = typename TPoller::TFileHandle;
TFileHandle input{0, poller}; // stdin
co_await socket.Connect();
Expand Down Expand Up @@ -144,7 +145,7 @@ NNet::TVoidTask Client(TPoller& poller, TSocket socket) {
}

void usage(const char* prog) {
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...]" << "\n";
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist]" << "\n";
exit(0);
}

Expand All @@ -155,6 +156,7 @@ int main(int argc, char** argv) {
TNodeDict nodes;
uint32_t id = 0;
bool server = false;
bool persist = false;
for (int i = 1; i < argc; i++) {
if (!strcmp(argv[i], "--server")) {
server = true;
Expand All @@ -163,6 +165,8 @@ int main(int argc, char** argv) {
hosts.push_back(THost{argv[++i]});
} else if (!strcmp(argv[i], "--id") && i < argc - 1) {
id = atoi(argv[++i]);
} else if (!strcmp(argv[i], "--persist")) {
persist = true;
} else if (!strcmp(argv[i], "--help")) {
usage(argv[0]);
}
Expand Down Expand Up @@ -193,7 +197,11 @@ int main(int argc, char** argv) {
}

std::shared_ptr<IRsm> rsm = std::make_shared<TKv>();
auto raft = std::make_shared<TRaft>(rsm, std::make_shared<TState>(), myHost.Id, nodes);
std::shared_ptr<IState> state = std::make_shared<TState>();
if (persist) {
state = std::make_shared<TDiskState>("state", myHost.Id);
}
auto raft = std::make_shared<TRaft>(rsm, state, myHost.Id, nodes);
TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller());
socket.Bind();
socket.Listen();
Expand All @@ -204,8 +212,10 @@ int main(int argc, char** argv) {
NNet::TAddress addr{hosts[0].Address, hosts[0].Port};
NNet::TSocket socket(std::move(addr), loop.Poller());

Client(loop.Poller(), std::move(socket));
loop.Loop();
auto h = Client(loop.Poller(), std::move(socket));
while (!h.done()) {
loop.Step();
}
}

return 0;
Expand Down
114 changes: 114 additions & 0 deletions src/persist.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include "persist.h"
#include <iostream>

TDiskState::TDiskState(const std::string& name, uint32_t id)
: Entries(Open(name + ".entries." + std::to_string(id)))
, Index(Open(name + ".index." + std::to_string(id)))
, State(Open(name + ".state." + std::to_string(id)))
{
// TODO: read all an validate
State.seekg(0, std::ios_base::end);
auto size = State.tellg();
if (size == sizeof(LastLogIndex) + sizeof(CurrentTerm) + sizeof(VotedFor)) {
State.seekg(0);
State.read((char*)&LastLogIndex, sizeof(LastLogIndex));
State.read((char*)&CurrentTerm, sizeof(CurrentTerm));
State.read((char*)&VotedFor, sizeof(VotedFor));
} else {
Commit();
}
if (LastLogIndex > 0) {
LastLogTerm = Get(LastLogIndex-1)->Term;
}
}

std::fstream TDiskState::Open(const std::string& name)
{
auto flags = std::ios::in | std::ios::out | std::ios::binary;
std::fstream f(name, flags);
if (!f.is_open()) {
std::ofstream tmp(name);
tmp.close();
f.open(name, flags);
}

if (!f.is_open()) {
throw std::runtime_error("Cannot open file: " + name);
}

return f;
}

TMessageHolder<TMessage> TDiskState::Read() const
{
decltype(TMessage::Type) type;
decltype(TMessage::Len) len;
if (!Entries.read((char*)&type, sizeof(type))) {
throw std::runtime_error("Error on read 1");
}
if (!Entries.read((char*)&len, sizeof(len))) {
throw std::runtime_error("Error on read 2");
}
auto mes = NewHoldedMessage<TMessage>(type, len);
if (!Entries.read((char*)mes->Value, len - sizeof(TMessage))) {
throw std::runtime_error("Error on read 3");
}
return mes;
}

void TDiskState::Write(TMessageHolder<TMessage> message)
{
Entries.write((char*)message.Mes, message->Len);
}

void TDiskState::RemoveLast()
{
if (LastLogIndex > 0) {
LastLogIndex--;
Commit();
}
}

void TDiskState::Append(TMessageHolder<TLogEntry> entry)
{
uint64_t offset = 0;
if (Get(LastLogIndex-1)) { // TODO: optimize
offset = Entries.tellg();
}

Write(entry);
Index.seekg(LastLogIndex * sizeof(offset));
Index.write((char*)&offset, sizeof(offset));
LastLogTerm = entry->Term;
LastLogIndex ++;
Commit();
}

TMessageHolder<TLogEntry> TDiskState::Get(int64_t index) const
{
if (index >= LastLogIndex || index < 0) {
return {};
}

uint64_t offset = 0;
Index.seekg(index * sizeof(offset));
if (!Index.read((char*)&offset, sizeof(offset))) {
return {};
}

Entries.seekg(offset);
auto entry = Read().Cast<TLogEntry>();
return entry;
}

void TDiskState::Commit()
{
State.seekg(0);
if (!State.write((char*)&LastLogIndex, sizeof(LastLogIndex))) { abort(); }
if (!State.write((char*)&CurrentTerm, sizeof(CurrentTerm))) { abort(); }
if (!State.write((char*)&VotedFor, sizeof(VotedFor))) { abort(); }
State.flush();
Index.flush();
Entries.flush();
}

24 changes: 24 additions & 0 deletions src/persist.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "state.h"

#include <fstream>
#include <string>

struct TDiskState : IState {
TDiskState(const std::string& name, uint32_t id);

void RemoveLast() override;
void Append(TMessageHolder<TLogEntry> entry) override;
TMessageHolder<TLogEntry> Get(int64_t index) const override;
void Commit() override;

private:
std::fstream Open(const std::string& name);
TMessageHolder<TMessage> Read() const;
void Write(TMessageHolder<TMessage>);

mutable std::fstream Entries;
mutable std::fstream Index;
std::fstream State;
};
1 change: 1 addition & 0 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state)
if (state.LogTerm(commitIndex) == state.CurrentTerm) {
CommitIndex = commitIndex;
}
// TODO: If state.LogTerm(commitIndex) < state.CurrentTerm need to append empty message to log
return *this;
}

Expand Down
60 changes: 1 addition & 59 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "messages.h"
#include "timesource.h"
#include "state.h"

struct INode {
virtual ~INode() = default;
Expand Down Expand Up @@ -37,65 +38,6 @@ struct TDummyRsm: public IRsm {

using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;

struct IState {
uint64_t CurrentTerm = 1;
uint32_t VotedFor = 0;
uint64_t LastLogIndex = 0;
uint64_t LastLogTerm = 0;

virtual uint64_t LogTerm(int64_t index = -1) const = 0;
virtual void RemoveLast() = 0;
virtual void Append(TMessageHolder<TLogEntry>) = 0;
virtual TMessageHolder<TLogEntry> Get(int64_t index) = 0;
virtual void Commit() = 0;
virtual ~IState() = default;
};

struct TState: IState {
std::vector<TMessageHolder<TLogEntry>> Log;

TState() = default;
TState(uint64_t currentTerm, uint32_t votedFor, const std::vector<TMessageHolder<TLogEntry>>& log)
: Log(log)
{
CurrentTerm = currentTerm;
VotedFor = votedFor;
if (!log.empty()) {
LastLogIndex = log.size();
LastLogTerm = log.back()->Term;
}
}

void RemoveLast() override {
Log.pop_back();
LastLogIndex = Log.size();
LastLogTerm = Log.empty() ? 0 : Log.back()->Term;
}

void Append(TMessageHolder<TLogEntry> entry) override {
Log.emplace_back(std::move(entry));
LastLogIndex = Log.size();
LastLogTerm = Log.back()->Term;
}

TMessageHolder<TLogEntry> Get(int64_t index) override {
return Log[index];
}

void Commit() override { }

uint64_t LogTerm(int64_t index = -1) const override {
if (index < 0) {
index = Log.size();
}
if (index < 1 || index > Log.size()) {
return 0;
} else {
return Log[index-1]->Term;
}
}
};

struct TVolatileState {
uint64_t CommitIndex = 0;
uint64_t LastApplied = 0;
Expand Down
62 changes: 62 additions & 0 deletions src/state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include "messages.h"

struct IState {
uint64_t CurrentTerm = 1;
uint32_t VotedFor = 0;
uint64_t LastLogIndex = 0;
uint64_t LastLogTerm = 0;

virtual void RemoveLast() = 0;
virtual void Append(TMessageHolder<TLogEntry>) = 0;
virtual TMessageHolder<TLogEntry> Get(int64_t index) const = 0;
virtual void Commit() = 0;
virtual ~IState() = default;

uint64_t LogTerm(int64_t index = -1) const {
if (index < 0) {
index = LastLogIndex;
}
if (index < 1 || index > LastLogIndex) {
return 0;
} else {
return Get(index-1)->Term;
}
}
};

struct TState: IState {
std::vector<TMessageHolder<TLogEntry>> Log;

TState() = default;
TState(uint64_t currentTerm, uint32_t votedFor, const std::vector<TMessageHolder<TLogEntry>>& log)
: Log(log)
{
CurrentTerm = currentTerm;
VotedFor = votedFor;
if (!log.empty()) {
LastLogIndex = log.size();
LastLogTerm = log.back()->Term;
}
}

void RemoveLast() override {
Log.pop_back();
LastLogIndex = Log.size();
LastLogTerm = Log.empty() ? 0 : Log.back()->Term;
}

void Append(TMessageHolder<TLogEntry> entry) override {
Log.emplace_back(std::move(entry));
LastLogIndex = Log.size();
LastLogTerm = Log.back()->Term;
}

TMessageHolder<TLogEntry> Get(int64_t index) const override {
return Log[index];
}

void Commit() override { }
};

Loading
Loading