diff --git a/example/server.c b/example/server.c index cd75c52c..8ccb40ec 100644 --- a/example/server.c +++ b/example/server.c @@ -100,6 +100,11 @@ static void FsmClose(struct raft_fsm *f) struct Server; typedef void (*ServerCloseCb)(struct Server *server); +typedef int (*raft_io_start_fn)(struct raft_io *io, + unsigned msecs, + raft_io_tick_cb tick, + raft_io_recv_cb recv); + struct Server { void *data; /* User data context. */ @@ -114,6 +119,8 @@ struct Server struct raft raft; /* Raft instance. */ struct raft_transfer transfer; /* Transfer leadership request. */ ServerCloseCb close_cb; /* Optional close callback. */ + raft_io_start_fn raft_start_fn; /* The original raft_io start function */ + raft_io_recv_cb raft_recv_cb; /* Original raft recv cb */ }; static void serverRaftCloseCb(struct raft *raft) @@ -153,6 +160,32 @@ static void serverTimerCloseCb(struct uv_handle_s *handle) } } +static void custom_recv_cb(struct raft_io *io, struct raft_message *msg) +{ + struct raft *r = io->data; + struct Server *s = r->data; + + if (msg->type == RAFT_REQUEST_CUSTOM) { + Logf(s->id, "Received custom request from %llu: %s", msg->server_id, + msg->server_address ? msg->server_address : ""); + return; + } + + s->raft_recv_cb(io, msg); +} + +static int custom_uv_start(struct raft_io *io, + unsigned msecs, + raft_io_tick_cb tick_cb, + raft_io_recv_cb recv_cb) +{ + struct raft *r = io->data; + struct Server *s = r->data; + s->raft_recv_cb = recv_cb; + + return s->raft_start_fn(io, msecs, tick_cb, &custom_recv_cb); +} + /* Initialize the example server struct, without starting it yet. */ static int ServerInit(struct Server *s, struct uv_loop_s *loop, @@ -194,6 +227,8 @@ static int ServerInit(struct Server *s, Logf(s->id, "raft_uv_init(): %s", s->io.errmsg); goto err_after_uv_tcp_init; } + s->raft_start_fn = s->io.start; + s->io.start = custom_uv_start; /* Initialize the finite state machine. */ rv = FsmInit(&s->fsm); @@ -275,6 +310,63 @@ static void serverApplyCb(struct raft_apply *req, int status, void *result) } } +static void sendCustomMessageCb(struct raft_io_send *send, int status) +{ + struct raft_message *message = send->data; + + (void)status; + + if (message) { + raft_free(message->request_custom.data.base); + raft_free(message); + } + raft_free(send); +} + +static void sendCustomMessageToLeader(struct Server *s) +{ + int rv; + raft_id leader_id; + const char *leader_address = NULL; + raft_leader(&s->raft, &leader_id, &leader_address); + if (!leader_address) { + return; + } + + struct raft_io_send *custom_req = raft_malloc(sizeof(struct raft_io_send)); + struct raft_message *message = raft_malloc(sizeof(struct raft_message)); + + if (!custom_req || !message) { + raft_free(custom_req); + raft_free(message); + Log(s->id, "serverTimerCb(): out of memory for custom request"); + return; + } + custom_req->data = message; + message->type = RAFT_REQUEST_CUSTOM; + message->server_id = leader_id; + message->server_address = leader_address; + message->request_custom.version = 1; + message->request_custom.term = 0; + message->request_custom.data.len = sizeof(uint64_t); + message->request_custom.data.base = raft_malloc(sizeof(uint64_t)); + if (!message->request_custom.data.base) { + sendCustomMessageCb(custom_req, 0); + Log(s->id, "serverTimerCb(): out of memory for custom request data"); + return; + } + + if (!s->io.send) { + Log(s->id, "serverTimerCb(): cannot send custom message"); + sendCustomMessageCb(custom_req, 0); + return; + } + rv = s->io.send(&s->io, custom_req, message, &sendCustomMessageCb); + if (rv != 0) { + Logf(s->id, "io.send(customMessage): %s", raft_errmsg(&s->raft)); + } +} + /* Called periodically every APPLY_RATE milliseconds. */ static void serverTimerCb(uv_timer_t *timer) { @@ -284,6 +376,7 @@ static void serverTimerCb(uv_timer_t *timer) int rv; if (s->raft.state != RAFT_LEADER) { + sendCustomMessageToLeader(s); return; } @@ -322,7 +415,7 @@ static int ServerStart(struct Server *s) Logf(s->id, "raft_start(): %s", raft_errmsg(&s->raft)); goto err; } - rv = uv_timer_start(&s->timer, serverTimerCb, 0, 125); + rv = uv_timer_start(&s->timer, serverTimerCb, 0, APPLY_RATE); if (rv != 0) { Logf(s->id, "uv_timer_start(): %s", uv_strerror(rv)); goto err;