diff --git a/src/bos/operators/base.py b/src/bos/operators/base.py index 9d27addb..29d4d7bd 100644 --- a/src/bos/operators/base.py +++ b/src/bos/operators/base.py @@ -27,6 +27,7 @@ """ from abc import ABC, abstractmethod +from contextlib import ExitStack import itertools import logging import threading @@ -55,6 +56,41 @@ class MissingSessionData(BaseOperatorException): desired state. """ +class ApiClients: + """ + Context manager to provide API clients to BOS operators. + Essentially, it uses an ExitStack context manager to manage the API clients. + """ + + def __init__(self): + #self.bos = BOSClient() + #self.bss = BSSClient() + #self.cfs = CFSClient() + #self.hsm = HSMClient() + #self.ims = IMSClient() + #self.pcs = PCSClient() + self._stack = ExitStack() + + def __enter__(self): + """ + Enter context for all API clients + """ + #self._stack.enter_context(self.bos) + #self._stack.enter_context(self.bss) + #self._stack.enter_context(self.cfs) + #self._stack.enter_context(self.hsm) + #self._stack.enter_context(self.ims) + #self._stack.enter_context(self.pcs) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit context on the exit stack, which will take care of exiting + context for all of the API clients. + """ + return self._stack.__exit__(exc_type, exc_val, exc_tb) + + class BaseOperator(ABC): """ @@ -76,6 +112,17 @@ class BaseOperator(ABC): def __init__(self) -> NoReturn: self.bos_client = BOSClient() self.__max_batch_size = 0 + self._client: ApiClients | None = None + + @property + def client(self) -> ApiClients: + """ + Return the ApiClients object for this operator. + If it is not initialized, raise a ValueError (this should never be the case). + """ + if self._client is None: + raise ValueError("Attempted to access uninitialized API client") + return self._client @property @abstractmethod @@ -98,9 +145,15 @@ def run(self) -> NoReturn: try: options.update() _update_log_level() - self._run() + with ApiClients() as _client: + self._client = _client + self._run() except Exception as e: LOGGER.exception('Unhandled exception detected: %s', e) + finally: + # We have exited the context manager, so make sure to reset the client + # value for this operator + self._client = None try: sleep_time = getattr(options, self.frequency_option) - (