Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply log entries to the state machine #3

Merged
merged 1 commit into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Feedback always welcome!
TBD

Missing:
- Apply commands to the state machine
- Safety - 5.4 and onwards (see https://raft.github.io/raft.pdf)
- Memberships
- RPC. Currently network is emulated in `TestCluster` by introducing non-Byzantine failures to message passing.
Expand Down
51 changes: 32 additions & 19 deletions src/main/scala/com.ariskk.raft/Raft.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import zio.clock._
import com.ariskk.raft.model._
import Message._
import Raft.{ MessageQueues }
import com.ariskk.raft.state.VolatileState
import com.ariskk.raft.volatile.VolatileState
import com.ariskk.raft.storage.Storage
import com.ariskk.raft.statemachine.StateMachine

Expand All @@ -20,8 +20,8 @@ import com.ariskk.raft.statemachine.StateMachine
final class Raft[T](
val nodeId: NodeId,
state: VolatileState,
storage: Storage[T],
queues: MessageQueues[T],
storage: Storage,
queues: MessageQueues,
stateMachine: StateMachine[T]
) {

Expand All @@ -33,7 +33,7 @@ final class Raft[T](

def offerAppendEntriesResponse(r: AppendEntriesResponse) = queues.appendResponseQueue.offer(r).commit

def offerAppendEntries(entries: AppendEntries[T]) = queues.appendEntriesQueue.offer(entries).commit
def offerAppendEntries(entries: AppendEntries) = queues.appendEntriesQueue.offer(entries).commit

def offerVoteRequest(request: VoteRequest) = queues.inboundVoteRequestQueue.offer(request).commit

Expand Down Expand Up @@ -69,7 +69,7 @@ final class Raft[T](
storage.getEntry(previousIndex).map(_.map(_.term).getOrElse(Term(-1)))
entries <- storage.getEntries(nextIndex)
_ <- sendMessage(
AppendEntries[T](
AppendEntries(
AppendEntries.newUniqueId,
nodeId,
p,
Expand Down Expand Up @@ -234,22 +234,28 @@ final class Raft[T](
_ <- if (!success) storage.purgeFrom(previousIndex) else ZSTM.unit
} yield success

private def appendLog(ae: AppendEntries[T]) =
private def appendLog(ae: AppendEntries) =
shouldAppend(ae.prevLogIndex, ae.prevLogTerm).flatMap { shouldAppend =>
if (shouldAppend) for {
_ <- storage.appendEntries(ae.entries.toList)
currentCommit <- state.lastCommitIndex
_ <-
if (ae.leaderCommitIndex > currentCommit)
storage.logSize.flatMap { size =>
state.updateCommitIndex(Index(math.min(size - 1, ae.leaderCommitIndex.index)))
val newCommitIndex = Index(math.min(size - 1, ae.leaderCommitIndex.index))
state.updateCommitIndex(newCommitIndex) *>
storage.getRange(currentCommit.increment, newCommitIndex).flatMap { entries =>
ZSTM.collectAll(
entries.map(e => stateMachine.write(e.command) *> state.incrementLastApplied)
)
}
}
else STM.unit
} yield ()
else STM.unit
}

private def processEntries(ae: AppendEntries[T]) = (for {
private def processEntries(ae: AppendEntries) = (for {
currentTerm <- storage.getTerm
currentState <- state.nodeState
lastIndex <- storage.logSize.map(x => Index(x - 1))
Expand Down Expand Up @@ -304,7 +310,10 @@ final class Raft[T](

def run = runFollowerLoop

private def processCommand(command: Command[T]) = {
private def applyToStateMachine(command: WriteCommand) =
stateMachine.write(command) *> state.incrementLastApplied

private def processCommand(command: WriteCommand) = {
lazy val processCommandProgram = for {
term <- storage.getTerm
_ <- storage.appendEntry(LogEntry(command, term))
Expand All @@ -314,14 +323,15 @@ final class Raft[T](
processCommandProgram.commit.flatMap { logSize =>
state.lastCommitIndex.commit
.repeatUntil(_.index >= logSize - 1)
.flatMap(_ => applyToStateMachine(command).commit)
.map(_ => Committed)
}
}

/**
* Blocks until it gets committed.
*/
def submitCommand(command: Command[T]): ZIO[Clock, RaftException, CommandResponse] =
def submitCommand(command: WriteCommand): ZIO[Clock, RaftException, CommandResponse] =
state.leader.commit.flatMap { leader =>
leader match {
case Some(leaderId) if leaderId == nodeId => processCommand(command)
Expand All @@ -330,15 +340,18 @@ final class Raft[T](
}
}

def submitQuery(query: ReadCommand): ZIO[Clock, RaftException, Option[T]] =
stateMachine.read(query).commit

}

object Raft {

case class MessageQueues[T](
case class MessageQueues(
inboundVoteResponseQueue: TQueue[VoteResponse],
inboundVoteRequestQueue: TQueue[VoteRequest],
appendResponseQueue: TQueue[AppendEntriesResponse],
appendEntriesQueue: TQueue[AppendEntries[T]],
appendEntriesQueue: TQueue[AppendEntries],
outboundQueue: TQueue[Message]
)

Expand All @@ -349,13 +362,13 @@ object Raft {
private def newQueue[T](queueSize: Int) =
TQueue.bounded[T](queueSize).commit

def default[T] = apply[T](DefaultQueueSize)
def default = apply(DefaultQueueSize)

def apply[T](queueSize: Int): UIO[MessageQueues[T]] = for {
def apply(queueSize: Int): UIO[MessageQueues] = for {
inboundVoteResponseQueue <- newQueue[VoteResponse](queueSize)
inboundVoteRequestQueue <- newQueue[VoteRequest](queueSize)
appendResponseQueue <- newQueue[AppendEntriesResponse](queueSize)
appendEntriesQueue <- newQueue[AppendEntries[T]](queueSize)
appendEntriesQueue <- newQueue[AppendEntries](queueSize)
outboundQueue <- newQueue[Message](queueSize)
} yield MessageQueues(
inboundVoteResponseQueue,
Expand All @@ -368,22 +381,22 @@ object Raft {

val LeaderHeartbeat = 50.milliseconds

def default[T](storage: Storage[T], stateMachine: StateMachine[T]): UIO[Raft[T]] = {
def default[T](storage: Storage, stateMachine: StateMachine[T]): UIO[Raft[T]] = {
val id = NodeId.newUniqueId
for {
state <- VolatileState(id, Set.empty[NodeId])
queues <- MessageQueues.default[T]
queues <- MessageQueues.default
} yield new Raft[T](id, state, storage, queues, stateMachine)
}

def apply[T](
nodeId: NodeId,
peers: Set[NodeId],
storage: Storage[T],
storage: Storage,
stateMachine: StateMachine[T]
) = for {
state <- VolatileState(nodeId, peers)
queues <- MessageQueues.default[T]
queues <- MessageQueues.default
} yield new Raft[T](nodeId, state, storage, queues, stateMachine)

}
8 changes: 3 additions & 5 deletions src/main/scala/com.ariskk.raft/model/Command.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.ariskk.raft.model

case class Key(value: String) extends AnyVal

sealed trait Command[T]
case class ReadCommand[T](key: Key) extends Command[T]
case class WriteCommand[T](key: Key, value: T) extends Command[T]
sealed trait Command
trait ReadCommand extends Command
trait WriteCommand extends Command
6 changes: 3 additions & 3 deletions src/main/scala/com.ariskk.raft/model/LogEntry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package com.ariskk.raft.model

import com.ariskk.raft.utils.Utils

case class LogEntry[T](
case class LogEntry(
id: LogEntry.Id,
command: Command[T],
command: WriteCommand,
term: Term
)

Expand All @@ -13,7 +13,7 @@ object LogEntry {

def newUniqueId = Id(Utils.newPrefixedId("entry"))

def apply[T](command: Command[T], term: Term): LogEntry[T] = LogEntry[T](
def apply(command: WriteCommand, term: Term): LogEntry = LogEntry(
newUniqueId,
command,
term
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/com.ariskk.raft/model/Message.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ object Message {
def newUniqueId = Id(Utils.newPrefixedId("append"))
}

final case class AppendEntries[T](
final case class AppendEntries(
appendId: AppendEntries.Id,
from: NodeId,
to: NodeId,
term: Term,
prevLogIndex: Index,
prevLogTerm: Term,
leaderCommitIndex: Index,
entries: Seq[LogEntry[T]]
entries: Seq[LogEntry]
) extends Message

final case class AppendEntriesResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ import zio.stm._

import com.ariskk.raft.model._

final case class Key(value: String) extends AnyVal
final case class ReadKey(key: Key) extends ReadCommand
final case class WriteKey[T](key: Key, value: T) extends WriteCommand

final class KeyValueStore[T](map: TMap[Key, T]) extends StateMachine[T] {
override def write(command: WriteCommand[T]) = map.put(command.key, command.value)
override def read(command: ReadCommand[T]) = map.get(command.key)
def write = { case WriteKey(key: Key, value: T) => map.put(key, value) }
def read = { case ReadKey(key) => map.get(key) }
}

object KeyValueStore {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ import com.ariskk.raft.model._
* Very simplistic modeling
*/
trait StateMachine[T] {
def write(command: WriteCommand[T]): STM[StateMachineException, Unit]
def read(command: ReadCommand[T]): STM[StateMachineException, Option[T]]
def write: PartialFunction[WriteCommand, STM[StateMachineException, Unit]]
def read: PartialFunction[ReadCommand, STM[StateMachineException, Option[T]]]
}
9 changes: 5 additions & 4 deletions src/main/scala/com.ariskk.raft/storage/Log.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import zio.stm._

import com.ariskk.raft.model._

trait Log[T] {
def append(entry: LogEntry[T]): STM[StorageException, Unit]
trait Log {
def append(entry: LogEntry): STM[StorageException, Unit]
def size: STM[StorageException, Long]
def getEntry(index: Index): STM[StorageException, Option[LogEntry[T]]]
def getEntries(index: Index): STM[StorageException, List[LogEntry[T]]]
def getEntry(index: Index): STM[StorageException, Option[LogEntry]]
def getEntries(index: Index): STM[StorageException, List[LogEntry]]
def purgeFrom(index: Index): STM[StorageException, Unit]
def getRange(from: Index, to: Index): STM[StorageException, List[LogEntry]]
}
18 changes: 10 additions & 8 deletions src/main/scala/com.ariskk.raft/storage/MemoryStorage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import com.ariskk.raft.model._
* Reference implemantation of `Storage` for testing purposes.
*/
final class MemoryStorage[T](
private[raft] listRef: TRef[List[LogEntry[T]]],
private[raft] listRef: TRef[List[LogEntry]],
private[raft] votedForRef: TRef[Option[Vote]],
private[raft] termRef: TRef[Term]
) extends Storage[T] {
) extends Storage {
lazy val log = new MemoryLog(listRef)

def storeVote(vote: Vote): STM[StorageException, Unit] = votedForRef.set(Option(vote))
Expand All @@ -26,22 +26,24 @@ final class MemoryStorage[T](

}

final class MemoryLog[T](log: TRef[List[LogEntry[T]]]) extends Log[T] {
def append(entry: LogEntry[T]): STM[StorageException, Unit] = log.update(_ :+ entry)
def size: STM[StorageException, Long] = log.get.map(_.size.toLong)
def getEntry(index: Index): STM[StorageException, Option[LogEntry[T]]] =
final class MemoryLog(log: TRef[List[LogEntry]]) extends Log {
def append(entry: LogEntry): STM[StorageException, Unit] = log.update(_ :+ entry)
def size: STM[StorageException, Long] = log.get.map(_.size.toLong)
def getEntry(index: Index): STM[StorageException, Option[LogEntry]] =
log.get.map(_.lift(index.index.toInt))
def getEntries(fromIndex: Index): STM[StorageException, List[LogEntry[T]]] =
def getEntries(fromIndex: Index): STM[StorageException, List[LogEntry]] =
log.get.map(_.drop(fromIndex.index.toInt))
def purgeFrom(index: Index): STM[StorageException, Unit] = log.get.map(l => l.dropRight(l.size - index.index.toInt))
def getRange(from: Index, to: Index) =
log.get.map(_.slice(from.index.toInt, to.increment.index.toInt))
}

object MemoryStorage {

def default[T]: UIO[MemoryStorage[T]] = for {
termRef <- TRef.makeCommit(Term.Zero)
votedForRef <- TRef.makeCommit(Option.empty[Vote])
log <- TRef.makeCommit(List.empty[LogEntry[T]])
log <- TRef.makeCommit(List.empty[LogEntry])
} yield new MemoryStorage(log, votedForRef, termRef)

}
17 changes: 10 additions & 7 deletions src/main/scala/com.ariskk.raft/storage/Storage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@ import com.ariskk.raft.model._
* to stable storage before responding to RPC.
* For more info, look at Figure 2 of https://raft.github.io/raft.pdf
*/
trait Storage[T] {
def log: Log[T]
trait Storage {
def log: Log

def appendEntry(entry: LogEntry[T]): STM[StorageException, Unit] = log.append(entry)
def appendEntry(entry: LogEntry): STM[StorageException, Unit] = log.append(entry)

def appendEntries(entries: List[LogEntry[T]]): STM[StorageException, Unit] = ZSTM
def appendEntries(entries: List[LogEntry]): STM[StorageException, Unit] = ZSTM
.collectAll(
entries.map(appendEntry)
)
.unit

def getEntry(index: Index): STM[StorageException, Option[LogEntry[T]]] = log.getEntry(index)
def getEntry(index: Index): STM[StorageException, Option[LogEntry]] = log.getEntry(index)

def getEntries(fromIndex: Index): STM[StorageException, List[LogEntry[T]]] =
def getEntries(fromIndex: Index): STM[StorageException, List[LogEntry]] =
log.getEntries(fromIndex)

def lastEntry: STM[StorageException, Option[LogEntry[T]]] = for {
def getRange(from: Index, to: Index): STM[StorageException, List[LogEntry]] =
log.getRange(from, to)

def lastEntry: STM[StorageException, Option[LogEntry]] = for {
size <- logSize
last <- getEntry(Index(size - 1))
} yield last
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.ariskk.raft.state
package com.ariskk.raft.volatile

import zio.stm._
import zio.UIO
Expand Down Expand Up @@ -52,6 +52,10 @@ final class VolatileState(
_ <- ZSTM.collectAll(peers.map(matchIndex.put(_, Index(0))))
} yield ()

def setLastApplied(index: Index) = lastApplied.set(index)

def incrementLastApplied = lastApplied.update(_.increment)

def nodeState = state.get

def addPeer(id: NodeId) = peers.put(id)
Expand Down
Loading