Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple distributed sql #20

Merged
merged 12 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ set(CMAKE_CXX_STANDARD 20)

find_package(PkgConfig REQUIRED)
pkg_check_modules(CMOCKA REQUIRED cmocka)
find_package(SQLite3)

add_subdirectory(coroio)

Expand All @@ -35,6 +36,12 @@ target_link_libraries(server miniraft coroio)
target_link_libraries(client miniraft coroio)
target_link_libraries(kv miniraft coroio)

if (SQLite3_FOUND)
add_executable(sql examples/sql.cpp)
target_include_directories(sql PRIVATE ${SQLite3_INCLUDE_DIRS})
target_link_libraries(sql PRIVATE ${SQLite3_LIBRARIES} miniraft coroio)
endif()

target_include_directories(test_raft PRIVATE ${CMOCKA_INCLUDE_DIRS})
target_link_directories(test_raft PRIVATE ${CMOCKA_LIBRARY_DIRS})
target_link_libraries(test_raft miniraft coroio ${CMOCKA_LIBRARIES})
Expand Down
3 changes: 2 additions & 1 deletion examples/kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TMessageHolder<TMessage> TKv::Read(TMessageHolder<TCommandRequest> message, uint
}
}

void TKv::Write(TMessageHolder<TLogEntry> message, uint64_t index) {
TMessageHolder<TMessage> TKv::Write(TMessageHolder<TLogEntry> message, uint64_t index) {
if (LastAppliedIndex < index) {
auto entry = message.Cast<TKvLogEntry>();
std::string_view k(entry->TKvEntry::Data, entry->KeySize);
Expand All @@ -63,6 +63,7 @@ void TKv::Write(TMessageHolder<TLogEntry> message, uint64_t index) {
}
LastAppliedIndex = index;
}
return {};
}

TMessageHolder<TLogEntry> TKv::Prepare(TMessageHolder<TCommandRequest> command, uint64_t term) {
Expand Down
2 changes: 1 addition & 1 deletion examples/kv.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class TKv: public IRsm {
public:
TMessageHolder<TMessage> Read(TMessageHolder<TCommandRequest> message, uint64_t index) override;
void Write(TMessageHolder<TLogEntry> message, uint64_t index) override;
TMessageHolder<TMessage> Write(TMessageHolder<TLogEntry> message, uint64_t index) override;
TMessageHolder<TLogEntry> Prepare(TMessageHolder<TCommandRequest> message, uint64_t term) override;

private:
Expand Down
323 changes: 323 additions & 0 deletions examples/sql.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
#include <sqlite3.h>
#include <iostream>

#include <raft.h>
#include <persist.h>
#include <server.h>

struct TSqlEntry {
uint32_t QuerySize = 0;
char Query[0];
};

struct TSqlLogEntry: public TLogEntry, public TSqlEntry
{
};

struct TWriteSql: public TCommandRequest, public TSqlEntry
{
};

struct TReadSql: public TCommandRequest, public TSqlEntry
{
};

struct TRow {
std::vector<std::optional<std::string>> Values;
};

struct TResult {
std::vector<std::string> Cols;
std::vector<TRow> Rows;

void Clear() {
Cols.clear();
Rows.clear();
}

bool Empty() {
return Cols.empty();
}
};

class TSql: public IRsm {
public:
TSql(const std::string& path, int serverId);
~TSql();

// select
TMessageHolder<TMessage> Read(TMessageHolder<TCommandRequest> message, uint64_t index) override;
// insert, update, create
TMessageHolder<TMessage> Write(TMessageHolder<TLogEntry> message, uint64_t index) override;
// convert request to log message
TMessageHolder<TLogEntry> Prepare(TMessageHolder<TCommandRequest> message, uint64_t term) override;

private:
bool Execute(const std::string& q);
static int Process(void* self, int ncols, char** values, char** cols);
TMessageHolder<TMessage> Reply(const std::string& ans, uint64_t index);

TResult Result;
std::string LastError;
uint64_t LastAppliedIndex = 0;
sqlite3* Db = nullptr;
};

TSql::TSql(const std::string& path, int serverId)
{
std::string dbPath = path + "." + std::to_string(serverId);
if (sqlite3_open(dbPath.c_str(), &Db) != SQLITE_OK) {
std::cerr << "Cannot open db: `" << dbPath << "', "
<< "error: " << sqlite3_errmsg(Db)
<< std::endl;
throw std::runtime_error("Cannot open db");
}
Execute(R"__(CREATE TABLE IF NOT EXISTS raft_metadata_ (key TEXT PRIMARY KEY, value TEXT))__");
Execute(R"__(SELECT value FROM raft_metadata_ WHERE key = 'LastAppliedIndex')__");
if (!Result.Empty()) {
LastAppliedIndex = std::stoi(*Result.Rows[0].Values[0]);
}
std::cerr << "LastAppliedIndex: " << LastAppliedIndex << std::endl;
}

TSql::~TSql()
{
if (Db) {
if (sqlite3_close(Db) != SQLITE_OK) {
std::cerr << "Failed to close db, error:" << sqlite3_errmsg(Db) << std::endl;
}
}
}

int TSql::Process(void* self, int ncols, char** values, char** cols) {
TSql* this_ = (TSql*)self;
if (this_->Result.Cols.empty()) {
for (int i = 0; i < ncols; i++) {
this_->Result.Cols.emplace_back(cols[i]);
}
}
TRow row;
for (int i = 0; i < ncols; i++) {
if (values[i]) {
row.Values.emplace_back(values[i]);
} else {
row.Values.emplace_back(std::nullopt);
}
}
this_->Result.Rows.emplace_back(std::move(row));
return 0;
}

bool TSql::Execute(const std::string& q) {
char* err = nullptr;
std::cerr << "Execute: " << q << std::endl;
Result.Clear();
LastError.clear();
if (sqlite3_exec(Db, q.c_str(), Process, this, &err) != SQLITE_OK) {
std::cerr << "Cannot apply changes, error: " << err << std::endl;
LastError = err;
sqlite3_free(err);
return false;
}
std::cerr << "OK" << std::endl;
return true;
}

TMessageHolder<TMessage> TSql::Read(TMessageHolder<TCommandRequest> message, uint64_t index) {
auto readSql = message.Cast<TReadSql>();
if (!Execute(std::string(readSql->Query, readSql->QuerySize))) {
return Reply(LastError, index);
} else {
std::string text;
for (int i = 0; i < Result.Cols.size(); i++) {
text += Result.Cols[i];
if (i != Result.Cols.size() - 1) {
text += ",";
}
}
text += "\n";
for (int j = 0; j < Result.Rows.size(); j++) {
for (int i = 0; i < Result.Cols.size(); i++) {
text += Result.Rows[j].Values[i] ? *Result.Rows[j].Values[i] : "null";
if (i != Result.Cols.size() - 1) {
text += ",";
}
}
text += "\n";
}
return Reply(text, index);
}
}

TMessageHolder<TMessage> TSql::Write(TMessageHolder<TLogEntry> message, uint64_t index) {
// TODO: index + 1 == LastAppliedIndex

std::string err;
if (LastAppliedIndex < index) {
auto entry = message.Cast<TSqlLogEntry>();
std::cerr << "Execute write of size: " << entry->QuerySize << std::endl;
std::string updateLastApplied;
updateLastApplied += "INSERT INTO raft_metadata_ (key, value) VALUES ('LastAppliedIndex','" + std::to_string(index) + "')\n";
updateLastApplied += "ON CONFLICT(key) DO UPDATE SET value = '" + std::to_string(index) + "';\n";
std::string q = "BEGIN TRANSACTION;\n";
q += std::string(entry->Query, entry->QuerySize);
if (q.back() != ';') {
q += ";\n";
}
q += updateLastApplied;
q += "COMMIT;";
if (Execute(q)) {
LastAppliedIndex = index;
} else {
err = LastError;
Execute("ROLLBACK;");
Execute(updateLastApplied); // need to update LastAppliedIndex in order not to execute failed query aqain
}
}
return Reply(err, index);
}

TMessageHolder<TMessage> TSql::Reply(const std::string& ans, uint64_t index)
{
auto res = NewHoldedMessage<TCommandResponse>(sizeof(TCommandResponse)+ans.size());
res->Index = index;
memcpy(res->Data, ans.data(), ans.size());
return res;
}

TMessageHolder<TLogEntry> TSql::Prepare(TMessageHolder<TCommandRequest> command, uint64_t term) {
auto dataSize = command->Len - sizeof(TCommandRequest);
std::cerr << "Prepare entry of size: " << dataSize << ", in term: " << term << std::endl;
auto entry = NewHoldedMessage<TLogEntry>(sizeof(TLogEntry)+dataSize);
memcpy(entry->Data, command->Data, dataSize);
entry->Term = term;
return entry;
}

template<typename TPoller, typename TSocket>
NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
using TFileHandle = typename TPoller::TFileHandle;
TFileHandle input{0, poller}; // stdin
co_await socket.Connect();
std::cout << "Connected\n";

NNet::TLine line;
TCommandRequest header;
header.Flags = TCommandRequest::EWrite;
header.Type = static_cast<uint32_t>(TCommandRequest::MessageType);
auto lineReader = NNet::TLineReader(input, 2*1024);
auto byteWriter = NNet::TByteWriter(socket);
const char* sep = " \t\r\n";

try {
while ((line = co_await lineReader.Read())) {
std::string strLine;
strLine += std::string_view(line.Part1.data(), line.Part1.size());
strLine += std::string_view(line.Part2.data(), line.Part2.size());
size_t pos = strLine.find(' ');
auto prefix = strLine.substr(0, pos);
TMessageHolder<TMessage> req;

int flags = 0;
if (!strcasecmp(prefix.data(), "create") || !strcasecmp(prefix.data(), "insert") || !strcasecmp(prefix.data(), "update")) {
auto mes = NewHoldedMessage<TWriteSql>(sizeof(TWriteSql) + strLine.size());
mes->Flags = TCommandRequest::EWrite;
mes->QuerySize = strLine.size();
memcpy(mes->Query, strLine.data(), strLine.size());
req = mes;
} else if (!strcasecmp(prefix.data(), "select")) {
auto mes = NewHoldedMessage<TReadSql>(sizeof(TReadSql) + strLine.size());
mes->QuerySize = strLine.size();
memcpy(mes->Query, strLine.data(), strLine.size());
req = mes;
} else {
std::cerr << "Cannot parse command: " << strLine << std::endl;
continue;
}
co_await TMessageWriter(socket).Write(std::move(req));
auto reply = co_await TMessageReader(socket).Read();
auto res = reply.template Cast<TCommandResponse>();
auto len = res->Len - sizeof(TCommandResponse);
std::string_view data(res->Data, len);
std::cout << "commitIndex: " << res->Index << "\n";
if (!data.empty()) {
std::cout << data << "\n";
}
}
} catch (const std::exception& ex) {
std::cout << "Exception: " << ex.what() << "\n";
}
co_return;
}

void usage(const char* prog) {
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...]" << "\n";
exit(0);
}

int main(int argc, char** argv)
{
signal(SIGPIPE, SIG_IGN);
std::vector<THost> hosts;
THost myHost;
TNodeDict nodes;
uint32_t id = 0;
bool server = false;
for (int i = 1; i < argc; i++) {
if (!strcmp(argv[i], "--server")) {
server = true;
} else if (!strcmp(argv[i], "--node") && i < argc - 1) {
// address:port:id
hosts.push_back(THost{argv[++i]});
} else if (!strcmp(argv[i], "--id") && i < argc - 1) {
id = atoi(argv[++i]);
} else if (!strcmp(argv[i], "--help")) {
usage(argv[0]);
}
}

using TPoller = NNet::TDefaultPoller;
std::shared_ptr<ITimeSource> timeSource = std::make_shared<TTimeSource>();
NNet::TLoop<TPoller> loop;

if (server) {
for (auto& host : hosts) {
if (!host) {
std::cerr << "Empty host\n"; return 1;
}
if (host.Id == id) {
myHost = host;
} else {
nodes[host.Id] = std::make_shared<TNode<TPoller::TSocket>>(
[&](const NNet::TAddress& addr) { return TPoller::TSocket(addr, loop.Poller()); },
std::to_string(host.Id),
NNet::TAddress{host.Address, host.Port},
timeSource);
}
}

if (!myHost) {
std::cerr << "Host not found\n"; return 1;
}

std::shared_ptr<IRsm> rsm = std::make_shared<TSql>("sql_file.db", myHost.Id);
auto state = std::make_shared<TDiskState>("sql_state", myHost.Id);
auto raft = std::make_shared<TRaft>(rsm, state, myHost.Id, nodes);
TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller());
socket.Bind();
socket.Listen();
TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource);
server.Serve();
loop.Loop();
} else {
NNet::TAddress addr{hosts[0].Address, hosts[0].Port};
NNet::TSocket socket(std::move(addr), loop.Poller());

auto h = Client(loop.Poller(), std::move(socket));
while (!h.done()) {
loop.Step();
}
}
return 0;
}

Loading
Loading