Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate definition and execution of Petri Nets #153

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions akka/src/main/scala/io/kagera/akka/actor/PetriNetInstance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object PetriNetInstance {

def petriNetInstancePersistenceId(processId: String): String = s"process-$processId"

def instanceState[S](instance: Instance[S]): InstanceState[S] = {
def instanceState[S, T <: Transition[_, _, S]](instance: Instance[S, T]): InstanceState[S] = {
val failures = instance.failedJobs.map { e =>
e.transitionId -> PetriNetInstanceProtocol.ExceptionState(
e.consecutiveFailureCount,
Expand All @@ -40,20 +40,23 @@ object PetriNetInstance {
InstanceState[S](instance.sequenceNr, instance.marking, instance.state, failures)
}

def props[S](topology: ExecutablePetriNet[S], settings: Settings = defaultSettings): Props =
Props(new PetriNetInstance[S](topology, settings, new TransitionExecutorImpl[IO, S](topology)))
def props[S, T <: Transition[Any, _, S]](topology: ExecutablePetriNet[S, T], settings: Settings = defaultSettings)(
implicit executorFactory: TransitionExecutorFactory.WithInputOutputState[IO, T, Any, _, S]
): Props =
Props(new PetriNetInstance[S, T](topology, settings, new TransitionExecutorImpl[IO, T](topology)))
}

/**
* This actor is responsible for maintaining the state of a single petri net instance.
*/
class PetriNetInstance[S](
override val topology: ExecutablePetriNet[S],
class PetriNetInstance[S, T <: Transition[Any, _, S]](
override val topology: ExecutablePetriNet[S, T],
val settings: Settings,
executor: TransitionExecutor[IO, S]
) extends PersistentActor
executor: TransitionExecutor[IO, T]
)(implicit executorFactory: TransitionExecutorFactory.WithInputOutputState[IO, T, Any, _, S])
extends PersistentActor
with ActorLogging
with PetriNetInstanceRecovery[S] {
with PetriNetInstanceRecovery[S, T] {

import PetriNetInstance._

Expand All @@ -68,7 +71,7 @@ class PetriNetInstance[S](
def uninitialized: Receive = {
case msg @ Initialize(marking, state) =>
log.debug(s"Received message: {}", msg)
val uninitialized = Instance.uninitialized[S](topology)
val uninitialized = Instance.uninitialized[S, T](topology)
persistEvent(uninitialized, InitializedEvent(marking, state.asInstanceOf[S])) {
(applyEvent(uninitialized) _)
.andThen(step)
Expand All @@ -79,7 +82,7 @@ class PetriNetInstance[S](
context.stop(context.self)
}

def running(instance: Instance[S]): Receive = {
def running(instance: Instance[S, T]): Receive = {
case IdleStop(n) if n == instance.sequenceNr && instance.activeJobs.isEmpty =>
context.stop(context.self)

Expand All @@ -105,7 +108,7 @@ class PetriNetInstance[S](

val updatedInstance = applyEvent(instance)(e)

def updateAndRespond(instance: Instance[S]) = {
def updateAndRespond(instance: Instance[S, T]) = {
sender() ! TransitionFailed(transitionId, consume, input, reason, strategy)
context become running(instance)
}
Expand All @@ -116,7 +119,9 @@ class PetriNetInstance[S](
s"Scheduling a retry of transition '${topology.transitions.getById(transitionId)}' in $delay milliseconds"
)
val originalSender = sender()
system.scheduler.scheduleOnce(delay milliseconds) { executeJob(updatedInstance.jobs(jobId), originalSender) }
system.scheduler.scheduleOnce(delay milliseconds) {
executeJob(updatedInstance.jobs(jobId), originalSender)
}
updateAndRespond(applyEvent(instance)(e))
case _ =>
persistEvent(instance, e)((applyEvent(instance) _).andThen(updateAndRespond _))
Expand All @@ -125,7 +130,7 @@ class PetriNetInstance[S](
case msg @ FireTransition(id, input, correlationId) =>
log.debug(s"Received message: {}", msg)

fireTransitionById[S](id, input).run(instance).value match {
fireTransitionById[S, T](id, input).run(instance).value match {
case (updatedInstance, Right(job)) =>
executeJob(job, sender())
context become running(updatedInstance)
Expand All @@ -138,7 +143,7 @@ class PetriNetInstance[S](
}

// TODO remove side effecting here
def step(instance: Instance[S]): Instance[S] = {
def step(instance: Instance[S, T]): Instance[S, T] = {
fireAllEnabledTransitions.run(instance).value match {
case (updatedInstance, jobs) =>
if (jobs.isEmpty && updatedInstance.activeJobs.isEmpty)
Expand All @@ -153,8 +158,10 @@ class PetriNetInstance[S](
}
}

def executeJob[E](job: Job[S, E], originalSender: ActorRef) =
runJobAsync(job, executor)(settings.evaluationStrategy).unsafeToFuture().pipeTo(context.self)(originalSender)
def executeJob(job: Job[S, T], originalSender: ActorRef) =
runJobAsync[S, T](job, executor)(settings.evaluationStrategy, executorFactory)
.unsafeToFuture()
.pipeTo(context.self)(originalSender)

override def onRecoveryCompleted(instance: Instance[S]) = step(instance)
override def onRecoveryCompleted(instance: Instance[S, T]) = step(instance)
}
17 changes: 12 additions & 5 deletions akka/src/main/scala/io/kagera/akka/actor/PetriNetInstanceApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,26 @@ class QueuePushingActor[E](queue: SourceQueueWithComplete[E], takeWhile: Any =>

object PetriNetInstanceApi {

def hasAutomaticTransitions[S](topology: ExecutablePetriNet[S]): InstanceState[S] => Boolean = state => {
def hasAutomaticTransitions[S, T <: Transition[_, _, S]](
topology: ExecutablePetriNet[S, T]
): InstanceState[S] => Boolean = state => {
state.marking.keySet
.map(p => topology.outgoingTransitions(p))
.foldLeft(Set.empty[Transition[_, _, _]]) { case (result, transitions) =>
.foldLeft(Set.empty[T]) { case (result, transitions) =>
result ++ transitions
}
.exists(isEnabledInState(topology, state))
}

def isEnabledInState[S](topology: ExecutablePetriNet[S], state: InstanceState[S])(t: Transition[_, _, _]): Boolean =
def isEnabledInState[S, T <: Transition[_, _, S]](topology: ExecutablePetriNet[S, T], state: InstanceState[S])(
t: T
): Boolean =
t.isAutomated && !state.hasFailed(t.id) && topology.isEnabledInMarking(state.marking.multiplicities)(t)

def takeWhileNotFailed[S](topology: ExecutablePetriNet[S], waitForRetries: Boolean): Any => Boolean = e =>
def takeWhileNotFailed[S, T <: Transition[_, _, S]](
topology: ExecutablePetriNet[S, T],
waitForRetries: Boolean
): Any => Boolean = e =>
e match {
case e: TransitionFired[S] => hasAutomaticTransitions(topology)(e.result)
case TransitionFailed(_, _, _, _, RetryWithDelay(delay)) => waitForRetries
Expand All @@ -75,7 +82,7 @@ object PetriNetInstanceApi {
/**
* Contains some methods to interact with a petri net instance actor.
*/
class PetriNetInstanceApi[S](topology: ExecutablePetriNet[S], actor: ActorRef)(implicit
class PetriNetInstanceApi[S, T <: Transition[_, _, S]](topology: ExecutablePetriNet[S, T], actor: ActorRef)(implicit
actorSystem: ActorSystem,
materializer: Materializer
) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
package io.kagera.akka.actor

import akka.persistence.{ PersistentActor, RecoveryCompleted }
import io.kagera.api.colored.ExecutablePetriNet
import io.kagera.api.colored.{ ExecutablePetriNet, Transition }
import io.kagera.execution.EventSourcing._
import io.kagera.execution.{ EventSourcing, Instance }
import io.kagera.persistence.{ messages, Serialization }

trait PetriNetInstanceRecovery[S] {
trait PetriNetInstanceRecovery[S, T <: Transition[_, _, S]] {

this: PersistentActor =>

def topology: ExecutablePetriNet[S]
def topology: ExecutablePetriNet[S, T]

implicit val system = context.system
val serializer = new Serialization(new AkkaObjectSerializer(context.system))

def onRecoveryCompleted(state: Instance[S])
def onRecoveryCompleted(state: Instance[S, T]): Unit

def applyEvent(i: Instance[S])(e: Event): Instance[S] = EventSourcing.applyEvent(e).runS(i).value
def applyEvent(i: Instance[S, T])(e: Event): Instance[S, T] = EventSourcing.applyEvent[S, T](e).runS(i).value

def persistEvent[T, E <: Event](instance: Instance[S], e: E)(fn: E => T): Unit = {
def persistEvent[R, E <: Event](instance: Instance[S, T], e: E)(fn: E => R): Unit = {
val serializedEvent = serializer.serializeEvent(e)(instance)
persist(serializedEvent) { persisted => fn.apply(e) }
}

private var recoveringState: Instance[S] = Instance.uninitialized[S](topology)
private var recoveringState: Instance[S, T] = Instance.uninitialized[S, T](topology)

private def applyToRecoveringState(e: AnyRef) = {
val deserializedEvent = serializer.deserializeEvent(e)(recoveringState)
Expand Down
12 changes: 6 additions & 6 deletions akka/src/main/scala/io/kagera/akka/query/PetriNetQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,30 @@ import akka.actor.ActorSystem
import akka.persistence.query.scaladsl._
import akka.stream.scaladsl._
import io.kagera.akka.actor.{ AkkaObjectSerializer, PetriNetInstance }
import io.kagera.api.colored.ExecutablePetriNet
import io.kagera.api.colored.{ ExecutablePetriNet, Transition }
import io.kagera.execution.EventSourcing._
import io.kagera.execution.Instance
import io.kagera.persistence.Serialization

trait PetriNetQuery[S] {
trait PetriNetQuery[S, T <: Transition[_, _, S]] {

def readJournal: ReadJournal with CurrentEventsByPersistenceIdQuery

def eventsForInstance(instanceId: String, topology: ExecutablePetriNet[S])(implicit
def eventsForInstance(instanceId: String, topology: ExecutablePetriNet[S, T])(implicit
actorSystem: ActorSystem
): (Source[(Instance[S], Event), NotUsed]) = {
): (Source[(Instance[S, T], Event), NotUsed]) = {

val serializer = new Serialization(new AkkaObjectSerializer(actorSystem))

val persistentId = PetriNetInstance.petriNetInstancePersistenceId(instanceId)
val src = readJournal.currentEventsByPersistenceId(persistentId, 0, Long.MaxValue)

src
.scan[(Instance[S], Event)]((Instance.uninitialized(topology), null.asInstanceOf[Event])) {
.scan[(Instance[S, T], Event)]((Instance.uninitialized(topology), null.asInstanceOf[Event])) {
case ((instance, prev), e) =>
val event = e.event.asInstanceOf[AnyRef]
val deserializedEvent = serializer.deserializeEvent(event)(instance)
val updatedInstance = applyEvent(deserializedEvent).runS(instance).value
val updatedInstance = applyEvent[S, T](deserializedEvent).runS(instance).value
(updatedInstance, deserializedEvent)
}
.drop(1)
Expand Down
99 changes: 55 additions & 44 deletions akka/src/main/scala/io/kagera/persistence/Serialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,17 @@ class Serialization(serializer: ObjectSerializer) {
* De-serializes a persistence.messages.Event to a EvenSourcing.Event. An Instance is required to 'wire' or
* 'reference' the message back into context.
*/
def deserializeEvent[S](event: AnyRef): Instance[S] => EventSourcing.Event = event match {
case e: messages.Initialized => deserializeInitialized(e)
case e: messages.TransitionFired => deserializeTransitionFired(e)
case e: messages.TransitionFailed => deserializeTransitionFailed(e)
}
def deserializeEvent[S, T <: Transition[_, _, S]](event: AnyRef): Instance[S, T] => EventSourcing.Event =
event match {
case e: messages.Initialized => deserializeInitialized(e)
case e: messages.TransitionFired => deserializeTransitionFired(e)
case e: messages.TransitionFailed => deserializeTransitionFailed(e)
}

/**
* Serializes an EventSourcing.Event to a persistence.messages.Event.
*/
def serializeEvent[S](e: EventSourcing.Event): Instance[S] => AnyRef =
def serializeEvent[S, T <: Transition[_, _, S]](e: EventSourcing.Event): Instance[S, T] => AnyRef =
instance =>
e match {
case e: InitializedEvent => serializeInitialized(e)
Expand All @@ -70,7 +71,10 @@ class Serialization(serializer: ObjectSerializer) {
}
}

private def deserializeProducedMarking[S](instance: Instance[S], produced: Seq[messages.ProducedToken]): Marking = {
private def deserializeProducedMarking[S, T <: Transition[_, _, S]](
instance: Instance[S, T],
produced: Seq[messages.ProducedToken]
): Marking = {
produced.foldLeft(Marking.empty) {
case (accumulated, messages.ProducedToken(Some(placeId), Some(tokenId), Some(count), data, _)) =>
val place = instance.process.places.getById(placeId)
Expand Down Expand Up @@ -104,7 +108,10 @@ class Serialization(serializer: ObjectSerializer) {
}
}

private def deserializeConsumedMarking[S](instance: Instance[S], e: messages.TransitionFired): Marking = {
private def deserializeConsumedMarking[S, T <: Transition[_, _, S]](
instance: Instance[S, T],
e: messages.TransitionFired
): Marking = {
e.consumed.foldLeft(Marking.empty) {
case (accumulated, messages.ConsumedToken(Some(placeId), Some(tokenId), Some(count), _)) =>
val place = instance.marking.keySet.getById(placeId)
Expand All @@ -114,7 +121,9 @@ class Serialization(serializer: ObjectSerializer) {
}
}

private def deserializeInitialized[S](e: messages.Initialized)(instance: Instance[S]): InitializedEvent = {
private def deserializeInitialized[S, T <: Transition[_, _, S]](
e: messages.Initialized
)(instance: Instance[S, T]): InitializedEvent = {
val initialMarking = deserializeProducedMarking(instance, e.initialMarking)
val initialState = e.initialState.map(serializer.deserializeObject).getOrElse(BoxedUnit.UNIT)
InitializedEvent(initialMarking, initialState)
Expand All @@ -126,31 +135,32 @@ class Serialization(serializer: ObjectSerializer) {
messages.Initialized(initialMarking, initialState)
}

private def deserializeTransitionFailed[S](e: messages.TransitionFailed): Instance[S] => TransitionFailedEvent = {
instance =>
val jobId = e.jobId.getOrElse(missingFieldException("job_id"))
val transitionId = e.transitionId.getOrElse(missingFieldException("transition_id"))
val timeStarted = e.timeStarted.getOrElse(missingFieldException("time_started"))
val timeFailed = e.timeFailed.getOrElse(missingFieldException("time_failed"))
val input = e.inputData.map(serializer.deserializeObject)
val failureReason = e.failureReason.getOrElse("")
val failureStrategy = e.failureStrategy.getOrElse(missingFieldException("time_failed")) match {
case FailureStrategy(Some(StrategyType.BLOCK_TRANSITION), _, _) => BlockTransition
case FailureStrategy(Some(StrategyType.BLOCK_ALL), _, _) => Fatal
case FailureStrategy(Some(StrategyType.RETRY), Some(delay), _) => RetryWithDelay(delay)
case other @ _ => throw new IllegalStateException(s"Invalid failure strategy: $other")
}
private def deserializeTransitionFailed[S, T <: Transition[_, _, S]](
e: messages.TransitionFailed
): Instance[S, T] => TransitionFailedEvent = { instance =>
val jobId = e.jobId.getOrElse(missingFieldException("job_id"))
val transitionId = e.transitionId.getOrElse(missingFieldException("transition_id"))
val timeStarted = e.timeStarted.getOrElse(missingFieldException("time_started"))
val timeFailed = e.timeFailed.getOrElse(missingFieldException("time_failed"))
val input = e.inputData.map(serializer.deserializeObject)
val failureReason = e.failureReason.getOrElse("")
val failureStrategy = e.failureStrategy.getOrElse(missingFieldException("time_failed")) match {
case FailureStrategy(Some(StrategyType.BLOCK_TRANSITION), _, _) => BlockTransition
case FailureStrategy(Some(StrategyType.BLOCK_ALL), _, _) => Fatal
case FailureStrategy(Some(StrategyType.RETRY), Some(delay), _) => RetryWithDelay(delay)
case other @ _ => throw new IllegalStateException(s"Invalid failure strategy: $other")
}

TransitionFailedEvent(
jobId,
transitionId,
timeStarted,
timeFailed,
Marking.empty,
None,
failureReason,
failureStrategy
)
TransitionFailedEvent(
jobId,
transitionId,
timeStarted,
timeFailed,
Marking.empty,
None,
failureReason,
failureStrategy
)
}

private def serializeTransitionFailed(e: TransitionFailedEvent): messages.TransitionFailed = {
Expand Down Expand Up @@ -188,19 +198,20 @@ class Serialization(serializer: ObjectSerializer) {
)
}

private def deserializeTransitionFired[S](e: messages.TransitionFired): Instance[S] => TransitionFiredEvent =
instance => {
private def deserializeTransitionFired[S, T <: Transition[_, _, S]](
e: messages.TransitionFired
): Instance[S, T] => TransitionFiredEvent = instance => {

val consumed: Marking = deserializeConsumedMarking(instance, e)
val produced: Marking = deserializeProducedMarking(instance, e.produced)
val consumed: Marking = deserializeConsumedMarking(instance, e)
val produced: Marking = deserializeProducedMarking(instance, e.produced)

val data = e.data.map(serializer.deserializeObject)
val data = e.data.map(serializer.deserializeObject)

val transitionId = e.transitionId.getOrElse(missingFieldException("transition_id"))
val jobId = e.jobId.getOrElse(missingFieldException("job_id"))
val timeStarted = e.timeStarted.getOrElse(missingFieldException("time_started"))
val timeCompleted = e.timeCompleted.getOrElse(missingFieldException("time_completed"))
val transitionId = e.transitionId.getOrElse(missingFieldException("transition_id"))
val jobId = e.jobId.getOrElse(missingFieldException("job_id"))
val timeStarted = e.timeStarted.getOrElse(missingFieldException("time_started"))
val timeCompleted = e.timeCompleted.getOrElse(missingFieldException("time_completed"))

TransitionFiredEvent(jobId, transitionId, timeStarted, timeCompleted, consumed, produced, data)
}
TransitionFiredEvent(jobId, transitionId, timeStarted, timeCompleted, consumed, produced, data)
}
}
Loading