Skip to content

Commit

Permalink
introduce inject methods
Browse files Browse the repository at this point in the history
  • Loading branch information
goFrendiAsgard committed Nov 9, 2023
1 parent c151f0c commit 5bc4e41
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 29 deletions.
7 changes: 1 addition & 6 deletions docs/tutorials/extending-cmd-task.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,8 @@ class SlackPrintTask(CmdTask):
self._slack_channel_id = slack_channel_id
self._slack_app_token = slack_app_token
self._message = message
self._slack_env_added = False

def get_envs(self) -> List[Env]:
if self._slack_env_added:
return super().get_env()
self._slack_env_added = True
def inject_envs(self):
self.add_envs(
Env(
name='CHANNEL_ID', os_name='',
Expand All @@ -75,7 +71,6 @@ class SlackPrintTask(CmdTask):
default=self.render_str(self._message)
)
)
return super().get_env()

def get_cmd_script(self, *args: Any, **kwargs: Any):
# contruct json payload and replace all `"` with `\\"`
Expand Down
20 changes: 20 additions & 0 deletions src/zrb/task/any_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,42 @@ def get_cmd_name(self) -> str:
def get_full_cmd_name(self) -> str:
pass

@abstractmethod
def inject_env_files(self):
pass

@abstractmethod
def _get_env_files(self) -> List[EnvFile]:
pass

@abstractmethod
def inject_envs(self):
pass

@abstractmethod
def _get_envs(self) -> List[Env]:
pass

@abstractmethod
def inject_inputs(self):
pass

@abstractmethod
def _get_inputs(self) -> List[AnyInput]:
pass

@abstractmethod
def inject_checkers(self):
pass

@abstractmethod
def _get_checkers(self) -> Iterable[TAnyTask]:
pass

@abstractmethod
def inject_upstreams(self):
pass

@abstractmethod
def _get_upstreams(self) -> Iterable[TAnyTask]:
pass
Expand Down
10 changes: 3 additions & 7 deletions src/zrb/task/base_remote_cmd_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from zrb.helper.typing import (
Any, Callable, Iterable, List, Mapping, Optional, Union, TypeVar
Any, Callable, Iterable, Mapping, Optional, Union, TypeVar
)
from zrb.helper.typecheck import typechecked
from zrb.helper.util import to_snake_case
Expand Down Expand Up @@ -122,15 +122,12 @@ def __init__(
self._post_cmd = post_cmd
self._post_cmd_path = post_cmd_path
self._remote_config = remote_config
self._is_additional_env_added = False

def copy(self) -> TSingleBaseRemoteCmdTask:
return copy.deepcopy(self)

def _get_envs(self) -> List[Env]:
if self._is_additional_env_added:
return super()._get_envs()
self._is_additional_env_added = True
def inject_envs(self):
super().inject_envs()
# add remote config properties as env
self.add_env(
Env(
Expand Down Expand Up @@ -165,7 +162,6 @@ def _get_envs(self) -> List[Env]:
default=rendered_val
)
)
return super()._get_envs()

def get_cmd_script(self, *args: Any, **kwargs: Any) -> str:
cmd_str = '\n'.join([
Expand Down
35 changes: 35 additions & 0 deletions src/zrb/task/base_task_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def __init__(
self._allow_add_env_files = True
self._allow_add_inputs = True
self._allow_add_upstreams: bool = True
self._has_already_inject_env_files: bool = False
self._has_already_inject_envs: bool = False
self._has_already_inject_inputs: bool = False
self._has_already_inject_checkers: bool = False
self._has_already_inject_upstreams: bool = False
self._execution_id = ''

def _set_execution_id(self, execution_id: str):
Expand Down Expand Up @@ -150,19 +155,49 @@ def get_icon(self) -> str:
def get_color(self) -> str:
return self._color

def inject_env_files(self):
pass

def _get_env_files(self) -> List[EnvFile]:
if not self._has_already_inject_env_files:
self.inject_env_files()
self._has_already_inject_env_files = True
return self._env_files

def inject_envs(self):
pass

def _get_envs(self) -> List[Env]:
if not self._has_already_inject_envs:
self.inject_envs()
self._has_already_inject_envs = True
return list(self._envs)

def inject_inputs(self):
pass

def _get_inputs(self) -> List[AnyInput]:
if not self._has_already_inject_inputs:
self.inject_inputs()
self._has_already_inject_inputs = True
return list(self._inputs)

def inject_checkers(self):
pass

def _get_checkers(self) -> List[AnyTask]:
if not self._has_already_inject_checkers:
self.inject_checkers()
self._has_already_inject_checkers = True
return list(self._checkers)

def inject_upstreams(self):
pass

def _get_upstreams(self) -> List[AnyTask]:
if not self._has_already_inject_upstreams:
self.inject_upstreams()
self._has_already_inject_upstreams = True
return list(self._upstreams)

def get_description(self) -> str:
Expand Down
8 changes: 2 additions & 6 deletions src/zrb/task/cmd_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def __init__(
self._executable = executable
self._process: Optional[asyncio.subprocess.Process]
self._preexec_fn = preexec_fn
self._is_cmd_aditional_env_added = False

def copy(self) -> TCmdTask:
return super().copy()
Expand All @@ -183,17 +182,14 @@ def print_result(self, result: CmdResult):
return
print(result.output)

def _get_envs(self) -> List[Env]:
if self._is_cmd_aditional_env_added:
return super()._get_envs()
self._is_cmd_aditional_env_added = True
def inject_envs(self):
super().inject_envs()
input_map = self.get_input_map()
for input_name, input_value in input_map.items():
env_name = '_INPUT_' + input_name.upper()
self.add_env(
Env(name=env_name, os_name='', default=str(input_value))
)
return super()._get_envs()

async def run(self, *args: Any, **kwargs: Any) -> CmdResult:
cmd = self.get_cmd_script(*args, **kwargs)
Expand Down
13 changes: 3 additions & 10 deletions src/zrb/task/docker_compose_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,8 @@ async def run(self, *args, **kwargs: Any) -> CmdResult:
os.remove(self._compose_runtime_file)
return result

def _get_envs(self) -> List[Env]:
if self._is_compose_additional_env_added:
return super()._get_envs()
self._is_compose_additional_env_added = True
def inject_envs(self):
super().inject_envs()
# inject envs from service_configs
for _, service_config in self._compose_service_configs.items():
self.insert_env(*service_config.get_envs())
Expand All @@ -188,16 +186,11 @@ def _get_envs(self) -> List[Env]:
if self._compose_env_prefix != '':
os_name = f'{self._compose_env_prefix}_{os_name}'
self.insert_env(Env(name=key, os_name=os_name, default=value))
return super()._get_envs()

def _get_env_files(self) -> List[EnvFile]:
if self._is_compose_additional_env_file_added:
return super().get_env_file()
self._is_compose_additional_env_file_added = True
def inject_env_files(self):
# inject env_files from service_configs
for _, service_config in self._compose_service_configs.items():
self.insert_env_file(*service_config.get_env_files())
return super()._get_env_files()

def _generate_compose_runtime_file(self):
compose_data = read_compose_file(self._compose_template_file)
Expand Down

0 comments on commit 5bc4e41

Please sign in to comment.