From 42e07a927dfa86e22c287fb8f78ea9e5805bd8f8 Mon Sep 17 00:00:00 2001 From: gofrendi Date: Fri, 29 Nov 2024 06:08:49 +0700 Subject: [PATCH] Add append methods --- src/zrb/task/any_task.py | 38 +++++++++++++++++++++- src/zrb/task/base_task.py | 67 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/src/zrb/task/any_task.py b/src/zrb/task/any_task.py index 242727e2..565e9c3d 100644 --- a/src/zrb/task/any_task.py +++ b/src/zrb/task/any_task.py @@ -74,15 +74,51 @@ def fallbacks(self) -> list["AnyTask"]: """Task fallbacks""" pass + @property + @abstractmethod + def successors(self) -> list["AnyTask"]: + """Task successors""" + pass + @property @abstractmethod def readiness_checks(self) -> list["AnyTask"]: """Task readiness checks""" pass + @abstractmethod + def append_fallbacks(self, fallbacks: "AnyTask" | list["AnyTask"]): + """Add the fallback tasks that this task depends on. + + Args: + fallbacks (AnyTask | list[AnyTask]): A single fallback task or + a list of fallback tasks. + """ + pass + + @abstractmethod + def append_successors(self, successors: "AnyTask" | list["AnyTask"]): + """Add the successor tasks that this task depends on. + + Args: + successors (AnyTask | list[AnyTask]): A single successor task or + a list of successor tasks. + """ + pass + + @abstractmethod + def append_readiness_checks(self, readiness_checks: "AnyTask" | list["AnyTask"]): + """Add the readiness_check tasks that this task depends on. + + Args: + readiness_checks (AnyTask | list[AnyTask]): A single readiness_check task or + a list of readiness_check tasks. + """ + pass + @abstractmethod def append_upstreams(self, upstreams: "AnyTask" | list["AnyTask"]): - """Sets the upstream tasks that this task depends on. + """Add the upstream tasks that this task depends on. Args: upstreams (AnyTask | list[AnyTask]): A single upstream task or diff --git a/src/zrb/task/base_task.py b/src/zrb/task/base_task.py index 24441dc5..5ccc14d9 100644 --- a/src/zrb/task/base_task.py +++ b/src/zrb/task/base_task.py @@ -38,6 +38,7 @@ def __init__( monitor_readiness: bool = False, upstream: list[AnyTask] | AnyTask | None = None, fallback: list[AnyTask] | AnyTask | None = None, + successor: list[AnyTask] | AnyTask | None = None, ): self._name = name self._color = color @@ -50,6 +51,7 @@ def __init__( self._retry_period = retry_period self._upstreams = upstream self._fallbacks = fallback + self._successors = successor self._readiness_checks = readiness_check self._readiness_check_delay = readiness_check_delay self._readiness_check_period = readiness_check_period @@ -142,6 +144,44 @@ def fallbacks(self) -> list[AnyTask]: return [self._fallbacks] return self._fallbacks + def append_fallbacks(self, fallbacks: AnyTask | list[AnyTask]): + fallback_list = [fallbacks] if isinstance(fallbacks, AnyTask) else fallbacks + for fallback in fallback_list: + self.__append_fallback(fallback) + + def __append_fallback(self, fallback: AnyTask): + # Make sure self._fallbacks is a list + if self._fallbacks is None: + self._fallbacks = [] + elif isinstance(self._fallbacks, AnyTask): + self._fallbacks = [self._fallbacks] + # Add fallback if it was not on self._fallbacks + if fallback not in self._fallbacks: + self._fallbacks.append(fallback) + + @property + def successors(self) -> list[AnyTask]: + if self._successors is None: + return [] + elif isinstance(self._successors, AnyTask): + return [self._successors] + return self._successors + + def append_successors(self, successors: AnyTask | list[AnyTask]): + successor_list = [successors] if isinstance(successors, AnyTask) else successors + for successor in successor_list: + self.__append_successor(successor) + + def __append_successor(self, successor: AnyTask): + # Make sure self._successors is a list + if self._successors is None: + self._successors = [] + elif isinstance(self._successors, AnyTask): + self._successors = [self._successors] + # Add successor if it was not on self._successors + if successor not in self._successors: + self._successors.append(successor) + @property def readiness_checks(self) -> list[AnyTask]: if self._readiness_checks is None: @@ -150,6 +190,25 @@ def readiness_checks(self) -> list[AnyTask]: return [self._readiness_checks] return self._readiness_checks + def append_readiness_checks(self, readiness_checks: AnyTask | list[AnyTask]): + readiness_check_list = ( + [readiness_checks] + if isinstance(readiness_checks, AnyTask) + else readiness_checks + ) + for readiness_check in readiness_check_list: + self.__append_readiness_check(readiness_check) + + def __append_readiness_check(self, readiness_check: AnyTask): + # Make sure self._readiness_checks is a list + if self._readiness_checks is None: + self._readiness_checks = [] + elif isinstance(self._readiness_checks, AnyTask): + self._readiness_checks = [self._readiness_checks] + # Add readiness_check if it was not on self._readiness_checks + if readiness_check not in self._readiness_checks: + self._readiness_checks.append(readiness_check) + @property def upstreams(self) -> list[AnyTask]: if self._upstreams is None: @@ -374,6 +433,7 @@ async def __exec_action_and_retry(self, session: AnySession) -> Any: # Put result on xcom task_xcom: Xcom = ctx.xcom.get(self.name) task_xcom.push(result) + await run_async(self.__exec_successors(session)) return result except (asyncio.CancelledError, KeyboardInterrupt): ctx.log_info("Marked as failed") @@ -390,6 +450,13 @@ async def __exec_action_and_retry(self, session: AnySession) -> Any: await run_async(self.__exec_fallbacks(session)) raise e + async def __exec_successors(self, session: AnySession) -> Any: + successors: list[AnyTask] = self.successors + successor_coros = [ + run_async(successor.exec_chain(session)) for successor in successors + ] + await asyncio.gather(*successor_coros) + async def __exec_fallbacks(self, session: AnySession) -> Any: fallbacks: list[AnyTask] = self.fallbacks fallback_coros = [