async def get_transport(authinfo, transport_queue, cancellable):
    """Return a transport for ``authinfo``, reusing the currently open one when possible.

    The transport returned by ``authinfo.get_transport()`` is handed back directly, and the queue's
    ``_last_request_special`` flag is set, only when all of the following hold:

    * the queue already holds a request for ``authinfo.pk``;
    * the queue's ``_last_request_special`` flag is not set;
    * that transport reports ``is_open``.

    In every other case a transport is requested through ``transport_queue.request_transport`` and the
    resulting request is awaited through ``cancellable.with_interrupt``.

    :param authinfo: object exposing ``pk`` and ``get_transport()``.
    :param transport_queue: the transport queue; its private ``_transport_requests`` mapping and
        ``_last_request_special`` flag are read (and the flag possibly set) here.
    :param cancellable: object exposing ``with_interrupt(awaitable)``, used to await the request.
    :return: the transport object. The previous implementation fell off the end of the function and
        therefore always returned ``None``.
    """
    # NOTE(review): this reaches into private attributes of the queue; the comments in the original
    # suggest moving this logic into the queue itself (an `obtain_transport`-style method).
    pending_request = transport_queue._transport_requests.get(authinfo.pk)

    if pending_request is not None and not transport_queue._last_request_special:
        transport = authinfo.get_transport()
        if transport.is_open:
            # Reuse the open transport and flag the request as special for the queue's scheduling.
            transport_queue._last_request_special = True
            return transport

    # NOTE(review): the request context is exited before the caller uses the transport. Whether the
    # queue keeps the transport open long enough after that exit needs to be confirmed.
    with transport_queue.request_transport(authinfo) as request:
        return await cancellable.with_interrupt(request)
# NOTE(review): the `do_submit` closure of `task_submit_job` (tasks.py) calls the `get_transport`
# coroutine without `await` and still contains a debug `print('a')`; both must be fixed at that call
# site before `execmanager.submit_calculation(node, transport)` can receive a real transport.


class _PendingRequest:
    """State shared by the request records kept by the transport queue.

    ``future`` is the awaitable handed to clients of the queue; for open requests the queue calls
    ``.result().close()`` on it, so its result is the transport object. ``count`` is incremented by the
    queue each time a client enters a request for the same authinfo.
    """

    def __init__(self):
        # Created here, so instances are expected to be built while an event loop is available.
        self.future: asyncio.Future = asyncio.Future()
        self.count = 0

    def __repr__(self) -> str:
        return f'{type(self).__name__}(count={self.count}, done={self.future.done()})'


class TransportRequest(_PendingRequest):
    """Information kept about request for a transport object"""


class TransportCloseRequest(_PendingRequest):
    """Information kept about close request for a transport object"""
This class allows clients @@ -54,7 +57,7 @@ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): self._last_open_time = None self._last_close_time = None self._last_request_special: bool = False - # self._last_submission_time = None + self._close_callback_handle = None # self._last_transport_request: Dict[Hashable, str] = {} @property @@ -77,8 +80,10 @@ async def transport_task(transport_queue, authinfo): :return: A future that can be yielded to give the transport """ open_callback_handle = None - # close_callback_handle = None + close_callback_handle = None transport_request = self._transport_requests.get(authinfo.pk, None) + # safe_open_interval = transport.get_safe_open_interval() + safe_open_interval = 30 if transport_request is None: # There is no existing request for this transport (i.e. on this authinfo) @@ -86,8 +91,6 @@ async def transport_task(transport_queue, authinfo): self._transport_requests[authinfo.pk] = transport_request transport = authinfo.get_transport() - # safe_open_interval = transport.get_safe_open_interval() - safe_open_interval = 30 # Check here if last_open_time > safe_interval, one could immediately open the transport # This should be the very first request, after a while @@ -116,30 +119,36 @@ def do_open(): # First request, submit immediately # ? Are these attributes persistet, or is a new TransportQueue instance created for every transport task? 
- if self._last_close_time is None: - open_callback_handle = self._loop.call_soon(do_open, context=contextvars.Context()) - self._last_request_special = True - - elif self._last_request_special: + if self._last_request_special: open_callback_handle = self._loop.call_later(safe_open_interval, do_open, context=contextvars.Context()) self._last_request_special = False + elif self._last_close_time is None: + open_callback_handle = self._loop.call_soon(do_open, context=contextvars.Context()) + self._last_request_special = True + else: - timedelta_seconds = (timezone.localtime(timezone.now()) - self._last_close_time).total_seconds() + close_timedelta = (timezone.localtime(timezone.now()) - self._last_close_time).total_seconds() + open_timedelta = (timezone.localtime(timezone.now()) - self._last_open_time).total_seconds() - if timedelta_seconds > safe_open_interval: + if open_timedelta > safe_open_interval: # ! This could also be `_loop.call_soon` which has an implicit delay of 0s - open_timedelta = timedelta_seconds-safe_open_interval - open_callback_handle = self._loop.call_later(open_timedelta, do_open, context=contextvars.Context()) + # open_timedelta = close_timedelta-safe_open_interval + open_callback_handle = self._loop.call_soon(do_open, context=contextvars.Context()) self._last_request_special = True else: - # If the last one was a special request, wait the safe_open_interval - open_callback_handle = self._loop.call_later(safe_open_interval, do_open, context=contextvars.Context()) + # If the last one was a special request, wait the difference between safe_open_interval and lost + open_callback_handle = self._loop.call_later(safe_open_interval-open_timedelta, do_open, context=contextvars.Context()) # open_callback_handle = self._loop.call_later(safe_open_interval, do_open, context=contextvars.Context()) + # ? This logic is implemented in `tasks.py` instead. 
+ # else: + # transport = authinfo.get_transport() + # return transport + # If transport_request is open already try: transport_request.count += 1 yield transport_request.future @@ -161,7 +170,7 @@ def do_open(): def do_close(): """Close the transport if conditions are met.""" transport_request.future.result().close() - # self._last_close_time = timezone.localtime(timezone.now()) + self._last_close_time = timezone.localtime(timezone.now()) close_timedelta = (timezone.localtime(timezone.now()) - self._last_open_time).total_seconds() @@ -170,16 +179,21 @@ def do_close(): # Also here logic when transport should be closed immediately, or when via call_later? close_callback_handle = self._loop.call_soon(do_close, context=contextvars.Context()) self._last_close_time = timezone.localtime(timezone.now()) + self._transport_requests.pop(authinfo.pk, None) else: close_callback_handle = self._loop.call_later(safe_open_interval, do_close, context=contextvars.Context()) + self._transport_requests.pop(authinfo.pk, None) # transport_request.transport_closer = close_callback_handle # This should be replaced with the call_later close_callback_handle invocation # transport_request.future.result().close() - + # ? When should the transport_request be popped? + # ? If it is always popped as soon as the task is done, there is no way to re-use it... + # self._transport_requests.pop(authinfo.pk, None) elif open_callback_handle is not None: open_callback_handle.cancel() - self._transport_requests.pop(authinfo.pk, None) + # ? Somewhere I still need to `pop` the transport_request... or do I? + # self._transport_requests.pop(authinfo.pk, None) diff --git a/src/aiida/engine/utils.py b/src/aiida/engine/utils.py index 4053156a97..88fc9e80f9 100644 --- a/src/aiida/engine/utils.py +++ b/src/aiida/engine/utils.py @@ -198,6 +198,7 @@ async def exponential_backoff_retry( result: Any = None coro = ensure_coroutine(fct) + print('a') interval = initial_interval for iteration in range(max_attempts):