Skip to content

Commit

Permalink
Add append methods
Browse files Browse the repository at this point in the history
  • Loading branch information
goFrendiAsgard committed Nov 28, 2024
1 parent 02c9ce3 commit 42e07a9
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/zrb/task/any_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,51 @@ def fallbacks(self) -> list["AnyTask"]:
"""Task fallbacks"""
pass

@property
@abstractmethod
def successors(self) -> list["AnyTask"]:
"""Task successors"""
pass

@property
@abstractmethod
def readiness_checks(self) -> list["AnyTask"]:
"""Task readiness checks"""
pass

@abstractmethod
def append_fallbacks(self, fallbacks: "AnyTask" | list["AnyTask"]):
"""Add the fallback tasks that this task depends on.
Args:
fallbacks (AnyTask | list[AnyTask]): A single fallback task or
a list of fallback tasks.
"""
pass

@abstractmethod
def append_successors(self, successors: "AnyTask" | list["AnyTask"]):
"""Add the successor tasks that this task depends on.
Args:
successors (AnyTask | list[AnyTask]): A single successor task or
a list of successor tasks.
"""
pass

@abstractmethod
def append_readiness_checks(self, readiness_checks: "AnyTask" | list["AnyTask"]):
"""Add the readiness_check tasks that this task depends on.
Args:
readiness_checks (AnyTask | list[AnyTask]): A single readiness_check task or
a list of readiness_check tasks.
"""
pass

@abstractmethod
def append_upstreams(self, upstreams: "AnyTask" | list["AnyTask"]):
"""Sets the upstream tasks that this task depends on.
"""Add the upstream tasks that this task depends on.
Args:
upstreams (AnyTask | list[AnyTask]): A single upstream task or
Expand Down
67 changes: 67 additions & 0 deletions src/zrb/task/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
monitor_readiness: bool = False,
upstream: list[AnyTask] | AnyTask | None = None,
fallback: list[AnyTask] | AnyTask | None = None,
successor: list[AnyTask] | AnyTask | None = None,
):
self._name = name
self._color = color
Expand All @@ -50,6 +51,7 @@ def __init__(
self._retry_period = retry_period
self._upstreams = upstream
self._fallbacks = fallback
self._successors = successor
self._readiness_checks = readiness_check
self._readiness_check_delay = readiness_check_delay
self._readiness_check_period = readiness_check_period
Expand Down Expand Up @@ -142,6 +144,44 @@ def fallbacks(self) -> list[AnyTask]:
return [self._fallbacks]
return self._fallbacks

def append_fallbacks(self, fallbacks: AnyTask | list[AnyTask]):
fallback_list = [fallbacks] if isinstance(fallbacks, AnyTask) else fallbacks
for fallback in fallback_list:
self.__append_fallback(fallback)

def __append_fallback(self, fallback: AnyTask):
# Make sure self._fallbacks is a list
if self._fallbacks is None:
self._fallbacks = []
elif isinstance(self._fallbacks, AnyTask):
self._fallbacks = [self._fallbacks]
# Add fallback if it was not on self._fallbacks
if fallback not in self._fallbacks:
self._fallbacks.append(fallback)

@property
def successors(self) -> list[AnyTask]:
if self._successors is None:
return []
elif isinstance(self._successors, AnyTask):
return [self._successors]
return self._successors

def append_successors(self, successors: AnyTask | list[AnyTask]):
successor_list = [successors] if isinstance(successors, AnyTask) else successors
for successor in successor_list:
self.__append_successor(successor)

def __append_successor(self, successor: AnyTask):
# Make sure self._successors is a list
if self._successors is None:
self._successors = []
elif isinstance(self._successors, AnyTask):
self._successors = [self._successors]
# Add successor if it was not on self._successors
if successor not in self._successors:
self._successors.append(successor)

@property
def readiness_checks(self) -> list[AnyTask]:
if self._readiness_checks is None:
Expand All @@ -150,6 +190,25 @@ def readiness_checks(self) -> list[AnyTask]:
return [self._readiness_checks]
return self._readiness_checks

def append_readiness_checks(self, readiness_checks: AnyTask | list[AnyTask]):
readiness_check_list = (
[readiness_checks]
if isinstance(readiness_checks, AnyTask)
else readiness_checks
)
for readiness_check in readiness_check_list:
self.__append_readiness_check(readiness_check)

def __append_readiness_check(self, readiness_check: AnyTask):
# Make sure self._readiness_checks is a list
if self._readiness_checks is None:
self._readiness_checks = []
elif isinstance(self._readiness_checks, AnyTask):
self._readiness_checks = [self._readiness_checks]
# Add readiness_check if it was not on self._readiness_checks
if readiness_check not in self._readiness_checks:
self._readiness_checks.append(readiness_check)

@property
def upstreams(self) -> list[AnyTask]:
if self._upstreams is None:
Expand Down Expand Up @@ -374,6 +433,7 @@ async def __exec_action_and_retry(self, session: AnySession) -> Any:
# Put result on xcom
task_xcom: Xcom = ctx.xcom.get(self.name)
task_xcom.push(result)
await run_async(self.__exec_successors(session))
return result
except (asyncio.CancelledError, KeyboardInterrupt):
ctx.log_info("Marked as failed")
Expand All @@ -390,6 +450,13 @@ async def __exec_action_and_retry(self, session: AnySession) -> Any:
await run_async(self.__exec_fallbacks(session))
raise e

async def __exec_successors(self, session: AnySession) -> Any:
successors: list[AnyTask] = self.successors
successor_coros = [
run_async(successor.exec_chain(session)) for successor in successors
]
await asyncio.gather(*successor_coros)

async def __exec_fallbacks(self, session: AnySession) -> Any:
fallbacks: list[AnyTask] = self.fallbacks
fallback_coros = [
Expand Down

0 comments on commit 42e07a9

Please sign in to comment.