// Patch overview (free-text preamble placed ahead of the first "diff --git" header).
//
// NOTE(review): this copy of the patch looks like a flattened extraction: the hunks are
//   collapsed onto a few long lines and `#include` targets / template arguments appear to
//   be missing (e.g. `std::vector> Values`, `message.Cast()`). Regenerate it from the
//   original commit before applying; the patch text below is left as found.
//
// What the patch changes:
//   * CMakeLists.txt: calls find_package(SQLite3) without REQUIRED and adds the `sql`
//     example target only inside `if (SQLite3_FOUND)`.
//   * src/raft.h, src/raft.cpp, examples/kv.h, examples/kv.cpp: IRsm::Write now returns a
//     TMessageHolder instead of void; TKv::Write and TDummyRsm::Write return `{}`.
//   * TRaft::ProcessCommitted pushes one TAnswer per applied index into WriteAnswers, and
//     substitutes a TCommandResponse carrying that index when Write returned an empty holder.
//   * TRaft::ProcessWaiting: for commands with the EWrite flag it drops queued answers whose
//     Index is below the waiter's index and replies with the next queued answer; other
//     commands still go through Rsm->Read. The member `waiting` is renamed to `Waiting`.
//   * examples/sql.cpp (new file): TSql, an IRsm implementation on top of sqlite3.
//       - The constructor opens "<path>.<serverId>", creates table raft_metadata_ if needed
//         and loads LastAppliedIndex from it.
//       - Write wraps the logged query and an upsert of LastAppliedIndex in
//         BEGIN TRANSACTION ... COMMIT; on failure it issues ROLLBACK, runs the upsert alone
//         and returns the error text in the reply.
//       - Read executes the query and replies with comma-separated column names followed by
//         one comma-separated line per row, printing NULL values as "null".
//       - Client reads lines from stdin; lines whose first word is create/insert/update are
//         sent as write requests, select as read requests, anything else is rejected locally.
//
// Review notes (not applied here, see the first note):
//   NOTE(review): the TSql constructor parses LastAppliedIndex with std::stoi although the
//     member is uint64_t; std::stoull looks like the intended call - confirm.
//   NOTE(review): when sqlite3_open fails the constructor throws without calling
//     sqlite3_close(Db); the destructor does not run for a throwing constructor, so the
//     handle is presumably leaked - confirm against the sqlite3_open documentation.
//   NOTE(review): TSql::Execute streams `err` and assigns it to LastError without a null
//     check; verify sqlite3_exec always sets the message on failure.
//   NOTE(review): on the failure path of TSql::Write the upsert is persisted but the
//     in-memory LastAppliedIndex is not advanced, so a repeated call with the same index
//     would execute the failed query again until restart - confirm intent.
//   NOTE(review): the client branch of main() reads hosts[0] without checking that any
//     --node argument was given.
//   NOTE(review): usage() prints `prog` immediately followed by "--client|--server" (no
//     separating space), and main() never parses a "--client" option.
//   NOTE(review): WriteAnswers is filled for every applied index but drained only while
//     serving waiting EWrite commands, and the empty/mismatch cases are guarded only by
//     assert; check growth on nodes with no waiting clients and behaviour under NDEBUG.
//
diff --git a/CMakeLists.txt b/CMakeLists.txt index 428c7ba..b00cee4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ set(CMAKE_CXX_STANDARD 20) find_package(PkgConfig REQUIRED) pkg_check_modules(CMOCKA REQUIRED cmocka) +find_package(SQLite3) add_subdirectory(coroio) @@ -35,6 +36,12 @@ target_link_libraries(server miniraft coroio) target_link_libraries(client miniraft coroio) target_link_libraries(kv miniraft coroio) +if (SQLite3_FOUND) +add_executable(sql examples/sql.cpp) +target_include_directories(sql PRIVATE ${SQLite3_INCLUDE_DIRS}) +target_link_libraries(sql PRIVATE ${SQLite3_LIBRARIES} miniraft coroio) +endif() + target_include_directories(test_raft PRIVATE ${CMOCKA_INCLUDE_DIRS}) target_link_directories(test_raft PRIVATE ${CMOCKA_LIBRARY_DIRS}) target_link_libraries(test_raft miniraft coroio ${CMOCKA_LIBRARIES}) diff --git a/examples/kv.cpp b/examples/kv.cpp index 70642c2..d5935aa 100644 --- a/examples/kv.cpp +++ b/examples/kv.cpp @@ -51,7 +51,7 @@ TMessageHolder TKv::Read(TMessageHolder message, uint } } -void TKv::Write(TMessageHolder message, uint64_t index) { +TMessageHolder TKv::Write(TMessageHolder message, uint64_t index) { if (LastAppliedIndex < index) { auto entry = message.Cast(); std::string_view k(entry->TKvEntry::Data, entry->KeySize); @@ -63,6 +63,7 @@ void TKv::Write(TMessageHolder message, uint64_t index) { } LastAppliedIndex = index; } + return {}; } TMessageHolder TKv::Prepare(TMessageHolder command, uint64_t term) { diff --git a/examples/kv.h b/examples/kv.h index fbde002..5df1bb7 100644 --- a/examples/kv.h +++ b/examples/kv.h @@ -8,7 +8,7 @@ class TKv: public IRsm { public: TMessageHolder Read(TMessageHolder message, uint64_t index) override; - void Write(TMessageHolder message, uint64_t index) override; + TMessageHolder Write(TMessageHolder message, uint64_t index) override; TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; private: diff --git a/examples/sql.cpp b/examples/sql.cpp new file mode 
100644 index 0000000..b447834 --- /dev/null +++ b/examples/sql.cpp @@ -0,0 +1,323 @@ +#include +#include + +#include +#include +#include + +struct TSqlEntry { + uint32_t QuerySize = 0; + char Query[0]; +}; + +struct TSqlLogEntry: public TLogEntry, public TSqlEntry +{ +}; + +struct TWriteSql: public TCommandRequest, public TSqlEntry +{ +}; + +struct TReadSql: public TCommandRequest, public TSqlEntry +{ +}; + +struct TRow { + std::vector> Values; +}; + +struct TResult { + std::vector Cols; + std::vector Rows; + + void Clear() { + Cols.clear(); + Rows.clear(); + } + + bool Empty() { + return Cols.empty(); + } +}; + +class TSql: public IRsm { +public: + TSql(const std::string& path, int serverId); + ~TSql(); + + // select + TMessageHolder Read(TMessageHolder message, uint64_t index) override; + // insert, update, create + TMessageHolder Write(TMessageHolder message, uint64_t index) override; + // convert request to log message + TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; + +private: + bool Execute(const std::string& q); + static int Process(void* self, int ncols, char** values, char** cols); + TMessageHolder Reply(const std::string& ans, uint64_t index); + + TResult Result; + std::string LastError; + uint64_t LastAppliedIndex = 0; + sqlite3* Db = nullptr; +}; + +TSql::TSql(const std::string& path, int serverId) +{ + std::string dbPath = path + "." 
+ std::to_string(serverId); + if (sqlite3_open(dbPath.c_str(), &Db) != SQLITE_OK) { + std::cerr << "Cannot open db: `" << dbPath << "', " + << "error: " << sqlite3_errmsg(Db) + << std::endl; + throw std::runtime_error("Cannot open db"); + } + Execute(R"__(CREATE TABLE IF NOT EXISTS raft_metadata_ (key TEXT PRIMARY KEY, value TEXT))__"); + Execute(R"__(SELECT value FROM raft_metadata_ WHERE key = 'LastAppliedIndex')__"); + if (!Result.Empty()) { + LastAppliedIndex = std::stoi(*Result.Rows[0].Values[0]); + } + std::cerr << "LastAppliedIndex: " << LastAppliedIndex << std::endl; +} + +TSql::~TSql() +{ + if (Db) { + if (sqlite3_close(Db) != SQLITE_OK) { + std::cerr << "Failed to close db, error:" << sqlite3_errmsg(Db) << std::endl; + } + } +} + +int TSql::Process(void* self, int ncols, char** values, char** cols) { + TSql* this_ = (TSql*)self; + if (this_->Result.Cols.empty()) { + for (int i = 0; i < ncols; i++) { + this_->Result.Cols.emplace_back(cols[i]); + } + } + TRow row; + for (int i = 0; i < ncols; i++) { + if (values[i]) { + row.Values.emplace_back(values[i]); + } else { + row.Values.emplace_back(std::nullopt); + } + } + this_->Result.Rows.emplace_back(std::move(row)); + return 0; +} + +bool TSql::Execute(const std::string& q) { + char* err = nullptr; + std::cerr << "Execute: " << q << std::endl; + Result.Clear(); + LastError.clear(); + if (sqlite3_exec(Db, q.c_str(), Process, this, &err) != SQLITE_OK) { + std::cerr << "Cannot apply changes, error: " << err << std::endl; + LastError = err; + sqlite3_free(err); + return false; + } + std::cerr << "OK" << std::endl; + return true; +} + +TMessageHolder TSql::Read(TMessageHolder message, uint64_t index) { + auto readSql = message.Cast(); + if (!Execute(std::string(readSql->Query, readSql->QuerySize))) { + return Reply(LastError, index); + } else { + std::string text; + for (int i = 0; i < Result.Cols.size(); i++) { + text += Result.Cols[i]; + if (i != Result.Cols.size() - 1) { + text += ","; + } + } + text += "\n"; + 
for (int j = 0; j < Result.Rows.size(); j++) { + for (int i = 0; i < Result.Cols.size(); i++) { + text += Result.Rows[j].Values[i] ? *Result.Rows[j].Values[i] : "null"; + if (i != Result.Cols.size() - 1) { + text += ","; + } + } + text += "\n"; + } + return Reply(text, index); + } +} + +TMessageHolder TSql::Write(TMessageHolder message, uint64_t index) { + // TODO: index + 1 == LastAppliedIndex + + std::string err; + if (LastAppliedIndex < index) { + auto entry = message.Cast(); + std::cerr << "Execute write of size: " << entry->QuerySize << std::endl; + std::string updateLastApplied; + updateLastApplied += "INSERT INTO raft_metadata_ (key, value) VALUES ('LastAppliedIndex','" + std::to_string(index) + "')\n"; + updateLastApplied += "ON CONFLICT(key) DO UPDATE SET value = '" + std::to_string(index) + "';\n"; + std::string q = "BEGIN TRANSACTION;\n"; + q += std::string(entry->Query, entry->QuerySize); + if (q.back() != ';') { + q += ";\n"; + } + q += updateLastApplied; + q += "COMMIT;"; + if (Execute(q)) { + LastAppliedIndex = index; + } else { + err = LastError; + Execute("ROLLBACK;"); + Execute(updateLastApplied); // need to update LastAppliedIndex in order not to execute failed query again + } + } + return Reply(err, index); +} + +TMessageHolder TSql::Reply(const std::string& ans, uint64_t index) +{ + auto res = NewHoldedMessage(sizeof(TCommandResponse)+ans.size()); + res->Index = index; + memcpy(res->Data, ans.data(), ans.size()); + return res; +} + +TMessageHolder TSql::Prepare(TMessageHolder command, uint64_t term) { + auto dataSize = command->Len - sizeof(TCommandRequest); + std::cerr << "Prepare entry of size: " << dataSize << ", in term: " << term << std::endl; + auto entry = NewHoldedMessage(sizeof(TLogEntry)+dataSize); + memcpy(entry->Data, command->Data, dataSize); + entry->Term = term; + return entry; +} + +template +NNet::TFuture Client(TPoller& poller, TSocket socket) { + using TFileHandle = typename TPoller::TFileHandle; + TFileHandle input{0, 
poller}; // stdin + co_await socket.Connect(); + std::cout << "Connected\n"; + + NNet::TLine line; + TCommandRequest header; + header.Flags = TCommandRequest::EWrite; + header.Type = static_cast(TCommandRequest::MessageType); + auto lineReader = NNet::TLineReader(input, 2*1024); + auto byteWriter = NNet::TByteWriter(socket); + const char* sep = " \t\r\n"; + + try { + while ((line = co_await lineReader.Read())) { + std::string strLine; + strLine += std::string_view(line.Part1.data(), line.Part1.size()); + strLine += std::string_view(line.Part2.data(), line.Part2.size()); + size_t pos = strLine.find(' '); + auto prefix = strLine.substr(0, pos); + TMessageHolder req; + + int flags = 0; + if (!strcasecmp(prefix.data(), "create") || !strcasecmp(prefix.data(), "insert") || !strcasecmp(prefix.data(), "update")) { + auto mes = NewHoldedMessage(sizeof(TWriteSql) + strLine.size()); + mes->Flags = TCommandRequest::EWrite; + mes->QuerySize = strLine.size(); + memcpy(mes->Query, strLine.data(), strLine.size()); + req = mes; + } else if (!strcasecmp(prefix.data(), "select")) { + auto mes = NewHoldedMessage(sizeof(TReadSql) + strLine.size()); + mes->QuerySize = strLine.size(); + memcpy(mes->Query, strLine.data(), strLine.size()); + req = mes; + } else { + std::cerr << "Cannot parse command: " << strLine << std::endl; + continue; + } + co_await TMessageWriter(socket).Write(std::move(req)); + auto reply = co_await TMessageReader(socket).Read(); + auto res = reply.template Cast(); + auto len = res->Len - sizeof(TCommandResponse); + std::string_view data(res->Data, len); + std::cout << "commitIndex: " << res->Index << "\n"; + if (!data.empty()) { + std::cout << data << "\n"; + } + } + } catch (const std::exception& ex) { + std::cout << "Exception: " << ex.what() << "\n"; + } + co_return; +} + +void usage(const char* prog) { + std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...]" << "\n"; + exit(0); +} + +int main(int argc, char** argv) +{ + 
signal(SIGPIPE, SIG_IGN); + std::vector hosts; + THost myHost; + TNodeDict nodes; + uint32_t id = 0; + bool server = false; + for (int i = 1; i < argc; i++) { + if (!strcmp(argv[i], "--server")) { + server = true; + } else if (!strcmp(argv[i], "--node") && i < argc - 1) { + // address:port:id + hosts.push_back(THost{argv[++i]}); + } else if (!strcmp(argv[i], "--id") && i < argc - 1) { + id = atoi(argv[++i]); + } else if (!strcmp(argv[i], "--help")) { + usage(argv[0]); + } + } + + using TPoller = NNet::TDefaultPoller; + std::shared_ptr timeSource = std::make_shared(); + NNet::TLoop loop; + + if (server) { + for (auto& host : hosts) { + if (!host) { + std::cerr << "Empty host\n"; return 1; + } + if (host.Id == id) { + myHost = host; + } else { + nodes[host.Id] = std::make_shared>( + [&](const NNet::TAddress& addr) { return TPoller::TSocket(addr, loop.Poller()); }, + std::to_string(host.Id), + NNet::TAddress{host.Address, host.Port}, + timeSource); + } + } + + if (!myHost) { + std::cerr << "Host not found\n"; return 1; + } + + std::shared_ptr rsm = std::make_shared("sql_file.db", myHost.Id); + auto state = std::make_shared("sql_state", myHost.Id); + auto raft = std::make_shared(rsm, state, myHost.Id, nodes); + TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller()); + socket.Bind(); + socket.Listen(); + TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource); + server.Serve(); + loop.Loop(); + } else { + NNet::TAddress addr{hosts[0].Address, hosts[0].Port}; + NNet::TSocket socket(std::move(addr), loop.Poller()); + + auto h = Client(loop.Poller(), std::move(socket)); + while (!h.done()) { + loop.Step(); + } + } + return 0; +} + diff --git a/src/raft.cpp b/src/raft.cpp index 25021c7..9187276 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -40,12 +40,13 @@ TMessageHolder TDummyRsm::Read(TMessageHolder message } } -void TDummyRsm::Write(TMessageHolder message, uint64_t index) +TMessageHolder TDummyRsm::Write(TMessageHolder 
message, uint64_t index) { if (LastAppliedIndex < index) { Log.emplace_back(std::move(message)); LastAppliedIndex = index; } + return {}; } TMessageHolder TDummyRsm::Prepare(TMessageHolder command, uint64_t term) @@ -276,7 +277,7 @@ void TRaft::OnCommandRequest(TMessageHolder command, const std: } auto index = State->LastLogIndex; if (replyTo) { - waiting.emplace(TWaiting{index, std::move(command), replyTo}); + Waiting.emplace(TWaiting{index, std::move(command), replyTo}); } } @@ -386,18 +387,28 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con void TRaft::ProcessCommitted() { auto commitIndex = VolatileState->CommitIndex; for (auto i = VolatileState->LastApplied+1; i <= commitIndex; i++) { - Rsm->Write(State->Get(i-1), i); + auto reply = Rsm->Write(State->Get(i-1), i); + WriteAnswers.emplace(TAnswer { + .Index = i, + .Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i}) + }); } VolatileState->LastApplied = commitIndex; } void TRaft::ProcessWaiting() { auto lastApplied = VolatileState->LastApplied; - while (!waiting.empty() && waiting.top().Index <= lastApplied) { - auto w = waiting.top(); waiting.pop(); + while (!Waiting.empty() && Waiting.top().Index <= lastApplied) { + auto w = Waiting.top(); Waiting.pop(); TMessageHolder reply; if (w.Command->Flags & TCommandRequest::EWrite) { - reply = NewHoldedMessage(TCommandResponse {.Index = w.Index}); + while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) { + WriteAnswers.pop(); + } + assert(!WriteAnswers.empty()); + auto answer = std::move(WriteAnswers.front()); WriteAnswers.pop(); + assert(answer.Index == w.Index); + reply = std::move(answer.Reply); } else { reply = Rsm->Read(std::move(w.Command), w.Index); } diff --git a/src/raft.h b/src/raft.h index aed23dd..b22ae38 100644 --- a/src/raft.h +++ b/src/raft.h @@ -22,13 +22,13 @@ struct INode { struct IRsm { virtual ~IRsm() = default; virtual TMessageHolder Read(TMessageHolder message, uint64_t index) = 0; - 
virtual void Write(TMessageHolder message, uint64_t index) = 0; + virtual TMessageHolder Write(TMessageHolder message, uint64_t index) = 0; virtual TMessageHolder Prepare(TMessageHolder message, uint64_t term) = 0; }; struct TDummyRsm: public IRsm { TMessageHolder Read(TMessageHolder message, uint64_t index) override; - void Write(TMessageHolder message, uint64_t index) override; + TMessageHolder Write(TMessageHolder message, uint64_t index) override; TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; private: @@ -148,8 +148,15 @@ class TRaft { return Index > other.Index; } }; - std::priority_queue waiting; + std::priority_queue Waiting; + + struct TAnswer { + uint64_t Index; + TMessageHolder Reply; + }; + std::queue WriteAnswers; EState StateName; uint32_t Seed = 31337; }; +