From d57472b4fa2213ec551197ee2e147aef364fdcfe Mon Sep 17 00:00:00 2001 From: Victor Zhestkov Date: Wed, 15 May 2024 11:47:35 +0200 Subject: [PATCH] Prevent OOM with high amount of batch async calls (bsc#1216063) * Refactor batch_async implementation * Fix batch_async tests after refactoring --- salt/cli/batch_async.py | 584 ++++++++++++++------- salt/master.py | 9 +- tests/pytests/unit/cli/test_batch_async.py | 360 +++++++------ 3 files changed, 597 insertions(+), 356 deletions(-) diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py index 1012ce37cca..5d49993faa7 100644 --- a/salt/cli/batch_async.py +++ b/salt/cli/batch_async.py @@ -2,18 +2,193 @@ Execute a job on the targeted minions by using a moving window of fixed size `batch`. """ -import gc - -# pylint: enable=import-error,no-name-in-module,redefined-builtin import logging +import re import salt.client import salt.ext.tornado +import salt.utils.event from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum +from salt.ext.tornado.iostream import StreamClosedError log = logging.getLogger(__name__) +__SHARED_EVENTS_CHANNEL = None + + +def _get_shared_events_channel(opts, io_loop): + global __SHARED_EVENTS_CHANNEL + if __SHARED_EVENTS_CHANNEL is None: + __SHARED_EVENTS_CHANNEL = SharedEventsChannel(opts, io_loop) + return __SHARED_EVENTS_CHANNEL + + +def _destroy_unused_shared_events_channel(): + global __SHARED_EVENTS_CHANNEL + if __SHARED_EVENTS_CHANNEL is not None and __SHARED_EVENTS_CHANNEL.destroy_unused(): + __SHARED_EVENTS_CHANNEL = None + + +def batch_async_required(opts, minions, extra): + """ + Check opts to identify if batch async is required for the operation. 
+ """ + if not isinstance(minions, list): + return False + batch_async_opts = opts.get("batch_async", {}) + batch_async_threshold = ( + batch_async_opts.get("threshold", 1) + if isinstance(batch_async_opts, dict) + else 1 + ) + if batch_async_threshold == -1: + batch_size = get_bnum(extra, minions, True) + return len(minions) >= batch_size + elif batch_async_threshold > 0: + return len(minions) >= batch_async_threshold + return False + + +class SharedEventsChannel: + def __init__(self, opts, io_loop): + self.io_loop = io_loop + self.local_client = salt.client.get_local_client( + opts["conf_file"], io_loop=self.io_loop + ) + self.master_event = salt.utils.event.get_event( + "master", + sock_dir=self.local_client.opts["sock_dir"], + opts=self.local_client.opts, + listen=True, + io_loop=self.io_loop, + keep_loop=True, + ) + self.master_event.set_event_handler(self.__handle_event) + if self.master_event.subscriber.stream: + self.master_event.subscriber.stream.set_close_callback(self.__handle_close) + self._re_tag_ret_event = re.compile(r"salt\/job\/(\d+)\/ret\/.*") + self._subscribers = {} + self._subscriptions = {} + self._used_by = set() + batch_async_opts = opts.get("batch_async", {}) + if not isinstance(batch_async_opts, dict): + batch_async_opts = {} + self._subscriber_reconnect_tries = batch_async_opts.get( + "subscriber_reconnect_tries", 5 + ) + self._subscriber_reconnect_interval = batch_async_opts.get( + "subscriber_reconnect_interval", 1.0 + ) + self._reconnecting_subscriber = False + + def subscribe(self, jid, op, subscriber_id, handler): + if subscriber_id not in self._subscribers: + self._subscribers[subscriber_id] = set() + if jid not in self._subscriptions: + self._subscriptions[jid] = [] + self._subscribers[subscriber_id].add(jid) + if (op, subscriber_id, handler) not in self._subscriptions[jid]: + self._subscriptions[jid].append((op, subscriber_id, handler)) + if not self.master_event.subscriber.connected(): + self.__reconnect_subscriber() + + def 
unsubscribe(self, jid, op, subscriber_id): + if subscriber_id not in self._subscribers: + return + jids = self._subscribers[subscriber_id].copy() + if jid is not None: + jids = {jid} + for i_jid in jids: + self._subscriptions[i_jid] = list( + filter( + lambda x: not (op in (x[0], None) and x[1] == subscriber_id), + self._subscriptions.get(i_jid, []), + ) + ) + self._subscribers[subscriber_id].discard(i_jid) + self._subscriptions = dict(filter(lambda x: x[1], self._subscriptions.items())) + if not self._subscribers[subscriber_id]: + del self._subscribers[subscriber_id] + + @salt.ext.tornado.gen.coroutine + def __handle_close(self): + if not self._subscriptions: + return + log.warning("Master Event Subscriber was closed. Trying to reconnect...") + yield self.__reconnect_subscriber() + + @salt.ext.tornado.gen.coroutine + def __handle_event(self, raw): + if self.master_event is None: + return + try: + tag, data = self.master_event.unpack(raw) + tag_match = self._re_tag_ret_event.match(tag) + if tag_match: + jid = tag_match.group(1) + if jid in self._subscriptions: + for op, _, handler in self._subscriptions[jid]: + yield handler(tag, data, op) + except Exception as ex: # pylint: disable=W0703 + log.error( + "Exception occured while processing event: %s: %s", + tag, + ex, + exc_info=True, + ) + + @salt.ext.tornado.gen.coroutine + def __reconnect_subscriber(self): + if self.master_event.subscriber.connected() or self._reconnecting_subscriber: + return + self._reconnecting_subscriber = True + max_tries = max(1, int(self._subscriber_reconnect_tries)) + _try = 1 + while _try <= max_tries: + log.info( + "Trying to reconnect to event publisher (try %d of %d) ...", + _try, + max_tries, + ) + try: + yield self.master_event.subscriber.connect() + except StreamClosedError: + log.warning( + "Unable to reconnect to event publisher (try %d of %d)", + _try, + max_tries, + ) + if self.master_event.subscriber.connected(): + self.master_event.subscriber.stream.set_close_callback( 
self.__handle_close + ) + log.info("Event publisher connection restored") + self._reconnecting_subscriber = False + return + if _try < max_tries: + yield salt.ext.tornado.gen.sleep(self._subscriber_reconnect_interval) + _try += 1 + self._reconnecting_subscriber = False + + def use(self, subscriber_id): + self._used_by.add(subscriber_id) + return self + + def unuse(self, subscriber_id): + self._used_by.discard(subscriber_id) + + def destroy_unused(self): + if self._used_by: + return False + self.master_event.remove_event_handler(self.__handle_event) + self.master_event.destroy() + self.master_event = None + self.local_client.destroy() + self.local_client = None + return True + + class BatchAsync: """ Run a job on the targeted minions by using a moving window of fixed size `batch`. @@ -28,14 +203,14 @@ class BatchAsync: - gather_job_timeout: `find_job` timeout - timeout: time to wait before firing a `find_job` - When the batch stars, a `start` event is fired: + When the batch starts, a `start` event is fired: - tag: salt/batch//start - data: { "available_minions": self.minions, "down_minions": targeted_minions - presence_ping_minions } - When the batch ends, an `done` event is fired: + When the batch ends, a `done` event is fired: - tag: salt/batch//done - data: { "available_minions": self.minions, @@ -45,17 +220,26 @@ class BatchAsync: } """ - def __init__(self, parent_opts, jid_gen, clear_load): - ioloop = salt.ext.tornado.ioloop.IOLoop.current() - self.local = salt.client.get_local_client( - parent_opts["conf_file"], io_loop=ioloop + def __init__(self, opts, jid_gen, clear_load): + self.extra_job_kwargs = {} + kwargs = clear_load.get("kwargs", {}) + for kwarg in ("module_executors", "executor_opts"): + if kwarg in kwargs: + self.extra_job_kwargs[kwarg] = kwargs[kwarg] + elif kwarg in opts: + self.extra_job_kwargs[kwarg] = opts[kwarg] + self.io_loop = salt.ext.tornado.ioloop.IOLoop.current() + self.events_channel = _get_shared_events_channel(opts, 
self.io_loop).use( + id(self) ) if "gather_job_timeout" in clear_load["kwargs"]: clear_load["gather_job_timeout"] = clear_load["kwargs"].pop( "gather_job_timeout" ) else: - clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"] + clear_load["gather_job_timeout"] = self.events_channel.local_client.opts[ + "gather_job_timeout" + ] self.batch_presence_ping_timeout = clear_load["kwargs"].get( "batch_presence_ping_timeout", None ) @@ -64,8 +248,8 @@ def __init__(self, parent_opts, jid_gen, clear_load): clear_load.pop("tgt"), clear_load.pop("fun"), clear_load["kwargs"].pop("batch"), - self.local.opts, - **clear_load + self.events_channel.local_client.opts, + **clear_load, ) self.eauth = batch_get_eauth(clear_load["kwargs"]) self.metadata = clear_load["kwargs"].get("metadata", {}) @@ -78,54 +262,45 @@ def __init__(self, parent_opts, jid_gen, clear_load): self.jid_gen = jid_gen self.ping_jid = jid_gen() self.batch_jid = jid_gen() - self.find_job_jid = jid_gen() self.find_job_returned = set() + self.metadata.update({"batch_jid": self.batch_jid, "ping_jid": self.ping_jid}) self.ended = False - self.event = salt.utils.event.get_event( - "master", - self.opts["sock_dir"], - self.opts["transport"], - opts=self.opts, - listen=True, - io_loop=ioloop, - keep_loop=True, - ) + self.event = self.events_channel.master_event self.scheduled = False - self.patterns = set() def __set_event_handler(self): - ping_return_pattern = "salt/job/{}/ret/*".format(self.ping_jid) - batch_return_pattern = "salt/job/{}/ret/*".format(self.batch_jid) - self.event.subscribe(ping_return_pattern, match_type="glob") - self.event.subscribe(batch_return_pattern, match_type="glob") - self.patterns = { - (ping_return_pattern, "ping_return"), - (batch_return_pattern, "batch_run"), - } - self.event.set_event_handler(self.__event_handler) + self.events_channel.subscribe( + self.ping_jid, "ping_return", id(self), self.__event_handler + ) + self.events_channel.subscribe( + self.batch_jid, 
"batch_run", id(self), self.__event_handler + ) - def __event_handler(self, raw): + @salt.ext.tornado.gen.coroutine + def __event_handler(self, tag, data, op): if not self.event: return try: - mtag, data = self.event.unpack(raw) - for (pattern, op) in self.patterns: - if mtag.startswith(pattern[:-1]): - minion = data["id"] - if op == "ping_return": - self.minions.add(minion) - if self.targeted_minions == self.minions: - self.event.io_loop.spawn_callback(self.start_batch) - elif op == "find_job_return": - if data.get("return", None): - self.find_job_returned.add(minion) - elif op == "batch_run": - if minion in self.active: - self.active.remove(minion) - self.done_minions.add(minion) - self.event.io_loop.spawn_callback(self.schedule_next) - except Exception as ex: - log.error("Exception occured while processing event: {}".format(ex)) + minion = data["id"] + if op == "ping_return": + self.minions.add(minion) + if self.targeted_minions == self.minions: + yield self.start_batch() + elif op == "find_job_return": + if data.get("return", None): + self.find_job_returned.add(minion) + elif op == "batch_run": + if minion in self.active: + self.active.remove(minion) + self.done_minions.add(minion) + yield self.schedule_next() + except Exception as ex: # pylint: disable=W0703 + log.error( + "Exception occured while processing event: %s: %s", + tag, + ex, + exc_info=True, + ) def _get_next(self): to_run = ( @@ -139,176 +314,203 @@ def _get_next(self): ) return set(list(to_run)[:next_batch_size]) + @salt.ext.tornado.gen.coroutine def check_find_job(self, batch_minions, jid): - if self.event: - find_job_return_pattern = "salt/job/{}/ret/*".format(jid) - self.event.unsubscribe(find_job_return_pattern, match_type="glob") - self.patterns.remove((find_job_return_pattern, "find_job_return")) - - timedout_minions = batch_minions.difference( - self.find_job_returned - ).difference(self.done_minions) - self.timedout_minions = self.timedout_minions.union(timedout_minions) - self.active = 
self.active.difference(self.timedout_minions) - running = batch_minions.difference(self.done_minions).difference( - self.timedout_minions - ) + """ + Check if the job with specified ``jid`` was finished on the minions + """ + if not self.event: + return + self.events_channel.unsubscribe(jid, "find_job_return", id(self)) - if timedout_minions: - self.schedule_next() + timedout_minions = batch_minions.difference(self.find_job_returned).difference( + self.done_minions + ) + self.timedout_minions = self.timedout_minions.union(timedout_minions) + self.active = self.active.difference(self.timedout_minions) + running = batch_minions.difference(self.done_minions).difference( + self.timedout_minions + ) - if self.event and running: - self.find_job_returned = self.find_job_returned.difference(running) - self.event.io_loop.spawn_callback(self.find_job, running) + if timedout_minions: + yield self.schedule_next() + + if self.event and running: + self.find_job_returned = self.find_job_returned.difference(running) + yield self.find_job(running) @salt.ext.tornado.gen.coroutine def find_job(self, minions): - if self.event: - not_done = minions.difference(self.done_minions).difference( - self.timedout_minions + """ + Find if the job was finished on the minions + """ + if not self.event: + return + not_done = minions.difference(self.done_minions).difference( + self.timedout_minions + ) + if not not_done: + return + try: + jid = self.jid_gen() + self.events_channel.subscribe( + jid, "find_job_return", id(self), self.__event_handler ) - try: - if not_done: - jid = self.jid_gen() - find_job_return_pattern = "salt/job/{}/ret/*".format(jid) - self.patterns.add((find_job_return_pattern, "find_job_return")) - self.event.subscribe(find_job_return_pattern, match_type="glob") - ret = yield self.local.run_job_async( - not_done, - "saltutil.find_job", - [self.batch_jid], - "list", - gather_job_timeout=self.opts["gather_job_timeout"], - jid=jid, - **self.eauth - ) - yield 
salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"]) - if self.event: - self.event.io_loop.spawn_callback( - self.check_find_job, not_done, jid - ) - except Exception as ex: - log.error( - "Exception occured handling batch async: {}. Aborting execution.".format( - ex - ) - ) - self.close_safe() + ret = yield self.events_channel.local_client.run_job_async( + not_done, + "saltutil.find_job", + [self.batch_jid], + "list", + gather_job_timeout=self.opts["gather_job_timeout"], + jid=jid, + io_loop=self.io_loop, + listen=False, + **self.eauth, + ) + yield salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"]) + if self.event: + yield self.check_find_job(not_done, jid) + except Exception as ex: # pylint: disable=W0703 + log.error( + "Exception occured handling batch async: %s. Aborting execution.", + ex, + exc_info=True, + ) + self.close_safe() @salt.ext.tornado.gen.coroutine def start(self): + """ + Start the batch execution + """ + if not self.event: + return + self.__set_event_handler() + ping_return = yield self.events_channel.local_client.run_job_async( + self.opts["tgt"], + "test.ping", + [], + self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")), + gather_job_timeout=self.opts["gather_job_timeout"], + jid=self.ping_jid, + metadata=self.metadata, + io_loop=self.io_loop, + listen=False, + **self.eauth, + ) + self.targeted_minions = set(ping_return["minions"]) + # start batching even if not all minions respond to ping + yield salt.ext.tornado.gen.sleep( + self.batch_presence_ping_timeout or self.opts["gather_job_timeout"] + ) if self.event: - self.__set_event_handler() - ping_return = yield self.local.run_job_async( - self.opts["tgt"], - "test.ping", - [], - self.opts.get( - "selected_target_option", self.opts.get("tgt_type", "glob") - ), - gather_job_timeout=self.opts["gather_job_timeout"], - jid=self.ping_jid, - metadata=self.metadata, - **self.eauth - ) - self.targeted_minions = set(ping_return["minions"]) - # start batching even 
if not all minions respond to ping - yield salt.ext.tornado.gen.sleep( - self.batch_presence_ping_timeout or self.opts["gather_job_timeout"] - ) - if self.event: - self.event.io_loop.spawn_callback(self.start_batch) + yield self.start_batch() @salt.ext.tornado.gen.coroutine def start_batch(self): - if not self.initialized: - self.batch_size = get_bnum(self.opts, self.minions, True) - self.initialized = True - data = { - "available_minions": self.minions, - "down_minions": self.targeted_minions.difference(self.minions), - "metadata": self.metadata, - } - ret = self.event.fire_event( - data, "salt/batch/{}/start".format(self.batch_jid) - ) - if self.event: - self.event.io_loop.spawn_callback(self.run_next) + """ + Fire `salt/batch/*/start` and continue batch with `run_next` + """ + if self.initialized: + return + self.batch_size = get_bnum(self.opts, self.minions, True) + self.initialized = True + data = { + "available_minions": self.minions, + "down_minions": self.targeted_minions.difference(self.minions), + "metadata": self.metadata, + } + yield self.events_channel.master_event.fire_event_async( + data, f"salt/batch/{self.batch_jid}/start" + ) + if self.event: + yield self.run_next() @salt.ext.tornado.gen.coroutine def end_batch(self): + """ + End the batch and call safe closing + """ left = self.minions.symmetric_difference( self.done_minions.union(self.timedout_minions) ) - if not left and not self.ended: - self.ended = True - data = { - "available_minions": self.minions, - "down_minions": self.targeted_minions.difference(self.minions), - "done_minions": self.done_minions, - "timedout_minions": self.timedout_minions, - "metadata": self.metadata, - } - self.event.fire_event(data, "salt/batch/{}/done".format(self.batch_jid)) - - # release to the IOLoop to allow the event to be published - # before closing batch async execution - yield salt.ext.tornado.gen.sleep(1) - self.close_safe() + # Send salt/batch/*/done only if there is nothing to do + # and the event 
haven't been sent already + if left or self.ended: + return + self.ended = True + data = { + "available_minions": self.minions, + "down_minions": self.targeted_minions.difference(self.minions), + "done_minions": self.done_minions, + "timedout_minions": self.timedout_minions, + "metadata": self.metadata, + } + yield self.events_channel.master_event.fire_event_async( + data, f"salt/batch/{self.batch_jid}/done" + ) + + # release to the IOLoop to allow the event to be published + # before closing batch async execution + yield salt.ext.tornado.gen.sleep(1) + self.close_safe() def close_safe(self): - for (pattern, label) in self.patterns: - self.event.unsubscribe(pattern, match_type="glob") - self.event.remove_event_handler(self.__event_handler) + if self.events_channel is not None: + self.events_channel.unsubscribe(None, None, id(self)) + self.events_channel.unuse(id(self)) + self.events_channel = None + _destroy_unused_shared_events_channel() self.event = None - self.local = None - self.ioloop = None - del self - gc.collect() @salt.ext.tornado.gen.coroutine def schedule_next(self): - if not self.scheduled: - self.scheduled = True - # call later so that we maybe gather more returns - yield salt.ext.tornado.gen.sleep(self.batch_delay) - if self.event: - self.event.io_loop.spawn_callback(self.run_next) + if self.scheduled: + return + self.scheduled = True + # call later so that we maybe gather more returns + yield salt.ext.tornado.gen.sleep(self.batch_delay) + if self.event: + yield self.run_next() @salt.ext.tornado.gen.coroutine def run_next(self): + """ + Continue batch execution with the next targets + """ self.scheduled = False next_batch = self._get_next() - if next_batch: - self.active = self.active.union(next_batch) - try: - ret = yield self.local.run_job_async( - next_batch, - self.opts["fun"], - self.opts["arg"], - "list", - raw=self.opts.get("raw", False), - ret=self.opts.get("return", ""), - gather_job_timeout=self.opts["gather_job_timeout"], - 
jid=self.batch_jid, - metadata=self.metadata, - ) - - yield salt.ext.tornado.gen.sleep(self.opts["timeout"]) - - # The batch can be done already at this point, which means no self.event - if self.event: - self.event.io_loop.spawn_callback(self.find_job, set(next_batch)) - except Exception as ex: - log.error("Error in scheduling next batch: %s. Aborting execution", ex) - self.active = self.active.difference(next_batch) - self.close_safe() - else: + if not next_batch: yield self.end_batch() - gc.collect() + return + self.active = self.active.union(next_batch) + try: + ret = yield self.events_channel.local_client.run_job_async( + next_batch, + self.opts["fun"], + self.opts["arg"], + "list", + raw=self.opts.get("raw", False), + ret=self.opts.get("return", ""), + gather_job_timeout=self.opts["gather_job_timeout"], + jid=self.batch_jid, + metadata=self.metadata, + io_loop=self.io_loop, + listen=False, + **self.eauth, + **self.extra_job_kwargs, + ) - def __del__(self): - self.local = None - self.event = None - self.ioloop = None - gc.collect() + yield salt.ext.tornado.gen.sleep(self.opts["timeout"]) + + # The batch can be done already at this point, which means no self.event + if self.event: + yield self.find_job(set(next_batch)) + except Exception as ex: # pylint: disable=W0703 + log.error( + "Error in scheduling next batch: %s. Aborting execution", + ex, + exc_info=True, + ) + self.active = self.active.difference(next_batch) + self.close_safe() diff --git a/salt/master.py b/salt/master.py index 425b4121481..d7182d10b5a 100644 --- a/salt/master.py +++ b/salt/master.py @@ -2,6 +2,7 @@ This module contains all of the routines needed to set up a master server, this involves preparing the three listeners and the workers needed by the master. 
""" + import collections import copy import ctypes @@ -19,7 +20,6 @@ import salt.acl import salt.auth import salt.channel.server -import salt.cli.batch_async import salt.client import salt.client.ssh.client import salt.crypt @@ -55,6 +55,7 @@ import salt.utils.verify import salt.utils.zeromq import salt.wheel +from salt.cli.batch_async import BatchAsync, batch_async_required from salt.config import DEFAULT_INTERVAL from salt.defaults import DEFAULT_TARGET_DELIM from salt.ext.tornado.stack_context import StackContext @@ -2174,9 +2175,9 @@ def get_token(self, clear_load): def publish_batch(self, clear_load, minions, missing): batch_load = {} batch_load.update(clear_load) - batch = salt.cli.batch_async.BatchAsync( + batch = BatchAsync( self.local.opts, - functools.partial(self._prep_jid, clear_load, {}), + lambda: self._prep_jid(clear_load, {}), batch_load, ) ioloop = salt.ext.tornado.ioloop.IOLoop.current() @@ -2331,7 +2332,7 @@ def publish(self, clear_load): ), }, } - if extra.get("batch", None): + if extra.get("batch", None) and batch_async_required(self.opts, minions, extra): return self.publish_batch(clear_load, minions, missing) jid = self._prep_jid(clear_load, extra) diff --git a/tests/pytests/unit/cli/test_batch_async.py b/tests/pytests/unit/cli/test_batch_async.py index e0774ffff34..bc871aba54c 100644 --- a/tests/pytests/unit/cli/test_batch_async.py +++ b/tests/pytests/unit/cli/test_batch_async.py @@ -1,7 +1,7 @@ import pytest import salt.ext.tornado -from salt.cli.batch_async import BatchAsync +from salt.cli.batch_async import BatchAsync, batch_async_required from tests.support.mock import MagicMock, patch @@ -22,16 +22,44 @@ def batch(temp_salt_master): with patch("salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)): batch = BatchAsync( opts, - MagicMock(side_effect=["1234", "1235", "1236"]), + MagicMock(side_effect=["1234", "1235"]), { "tgt": "", "fun": "", - "kwargs": {"batch": "", "batch_presence_ping_timeout": 1}, + "kwargs": { + 
"batch": "", + "batch_presence_ping_timeout": 1, + "metadata": {"mykey": "myvalue"}, + }, }, ) yield batch +@pytest.mark.parametrize( + "threshold,minions,batch,expected", + [ + (1, 2, 200, True), + (1, 500, 200, True), + (0, 2, 200, False), + (0, 500, 200, False), + (-1, 2, 200, False), + (-1, 500, 200, True), + (-1, 9, 10, False), + (-1, 11, 10, True), + (10, 9, 8, False), + (10, 9, 10, False), + (10, 11, 8, True), + (10, 11, 10, True), + ], +) +def test_batch_async_required(threshold, minions, batch, expected): + minions_list = [f"minion{i}.example.org" for i in range(minions)] + batch_async_opts = {"batch_async": {"threshold": threshold}} + extra = {"batch": batch} + assert batch_async_required(batch_async_opts, minions_list, extra) == expected + + def test_ping_jid(batch): assert batch.ping_jid == "1234" @@ -40,10 +68,6 @@ def test_batch_jid(batch): assert batch.batch_jid == "1235" -def test_find_job_jid(batch): - assert batch.find_job_jid == "1236" - - def test_batch_size(batch): """ Tests passing batch value as a number @@ -55,58 +79,74 @@ def test_batch_size(batch): def test_batch_start_on_batch_presence_ping_timeout(batch): - # batch_async = BatchAsyncMock(); - batch.event = MagicMock() + future_ret = salt.ext.tornado.gen.Future() + future_ret.set_result({"minions": ["foo", "bar"]}) future = salt.ext.tornado.gen.Future() - future.set_result({"minions": ["foo", "bar"]}) - batch.local.run_job_async.return_value = future - with patch("salt.ext.tornado.gen.sleep", return_value=future): - # ret = batch_async.start(batch) + future.set_result({}) + with patch.object(batch, "events_channel", MagicMock()), patch( + "salt.ext.tornado.gen.sleep", return_value=future + ), patch.object(batch, "start_batch", return_value=future) as start_batch_mock: + batch.events_channel.local_client.run_job_async.return_value = future_ret ret = batch.start() - # assert start_batch is called later with batch_presence_ping_timeout as param - assert 
batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,) + # assert start_batch is called + start_batch_mock.assert_called_once() # assert test.ping called - assert batch.local.run_job_async.call_args[0] == ("*", "test.ping", [], "glob") + assert batch.events_channel.local_client.run_job_async.call_args[0] == ( + "*", + "test.ping", + [], + "glob", + ) # assert targeted_minions == all minions matched by tgt assert batch.targeted_minions == {"foo", "bar"} def test_batch_start_on_gather_job_timeout(batch): - # batch_async = BatchAsyncMock(); - batch.event = MagicMock() future = salt.ext.tornado.gen.Future() - future.set_result({"minions": ["foo", "bar"]}) - batch.local.run_job_async.return_value = future + future.set_result({}) + future_ret = salt.ext.tornado.gen.Future() + future_ret.set_result({"minions": ["foo", "bar"]}) batch.batch_presence_ping_timeout = None - with patch("salt.ext.tornado.gen.sleep", return_value=future): + with patch.object(batch, "events_channel", MagicMock()), patch( + "salt.ext.tornado.gen.sleep", return_value=future + ), patch.object( + batch, "start_batch", return_value=future + ) as start_batch_mock, patch.object( + batch, "batch_presence_ping_timeout", None + ): + batch.events_channel.local_client.run_job_async.return_value = future_ret # ret = batch_async.start(batch) ret = batch.start() - # assert start_batch is called later with gather_job_timeout as param - assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,) + # assert start_batch is called + start_batch_mock.assert_called_once() def test_batch_fire_start_event(batch): batch.minions = {"foo", "bar"} batch.opts = {"batch": "2", "timeout": 5} - batch.event = MagicMock() - batch.metadata = {"mykey": "myvalue"} - batch.start_batch() - assert batch.event.fire_event.call_args[0] == ( - { - "available_minions": {"foo", "bar"}, - "down_minions": set(), - "metadata": batch.metadata, - }, - "salt/batch/1235/start", - ) + with patch.object(batch, 
"events_channel", MagicMock()): + batch.start_batch() + assert batch.events_channel.master_event.fire_event_async.call_args[0] == ( + { + "available_minions": {"foo", "bar"}, + "down_minions": set(), + "metadata": batch.metadata, + }, + "salt/batch/1235/start", + ) def test_start_batch_calls_next(batch): - batch.run_next = MagicMock(return_value=MagicMock()) - batch.event = MagicMock() - batch.start_batch() - assert batch.initialized - assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.run_next,) + batch.initialized = False + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object(batch, "event", MagicMock()), patch.object( + batch, "events_channel", MagicMock() + ), patch.object(batch, "run_next", return_value=future) as run_next_mock: + batch.events_channel.master_event.fire_event_async.return_value = future + batch.start_batch() + assert batch.initialized + run_next_mock.assert_called_once() def test_batch_fire_done_event(batch): @@ -114,69 +154,52 @@ def test_batch_fire_done_event(batch): batch.minions = {"foo", "bar"} batch.done_minions = {"foo"} batch.timedout_minions = {"bar"} - batch.event = MagicMock() - batch.metadata = {"mykey": "myvalue"} - old_event = batch.event - batch.end_batch() - assert old_event.fire_event.call_args[0] == ( - { - "available_minions": {"foo", "bar"}, - "done_minions": batch.done_minions, - "down_minions": {"baz"}, - "timedout_minions": batch.timedout_minions, - "metadata": batch.metadata, - }, - "salt/batch/1235/done", - ) - - -def test_batch__del__(batch): - batch = BatchAsync(MagicMock(), MagicMock(), MagicMock()) - event = MagicMock() - batch.event = event - batch.__del__() - assert batch.local is None - assert batch.event is None - assert batch.ioloop is None + with patch.object(batch, "events_channel", MagicMock()): + batch.end_batch() + assert batch.events_channel.master_event.fire_event_async.call_args[0] == ( + { + "available_minions": {"foo", "bar"}, + "done_minions": 
batch.done_minions, + "down_minions": {"baz"}, + "timedout_minions": batch.timedout_minions, + "metadata": batch.metadata, + }, + "salt/batch/1235/done", + ) def test_batch_close_safe(batch): - batch = BatchAsync(MagicMock(), MagicMock(), MagicMock()) - event = MagicMock() - batch.event = event - batch.patterns = { - ("salt/job/1234/ret/*", "find_job_return"), - ("salt/job/4321/ret/*", "find_job_return"), - } - batch.close_safe() - assert batch.local is None - assert batch.event is None - assert batch.ioloop is None - assert len(event.unsubscribe.mock_calls) == 2 - assert len(event.remove_event_handler.mock_calls) == 1 + with patch.object( + batch, "events_channel", MagicMock() + ) as events_channel_mock, patch.object(batch, "event", MagicMock()): + batch.close_safe() + batch.close_safe() + assert batch.events_channel is None + assert batch.event is None + events_channel_mock.unsubscribe.assert_called_once() + events_channel_mock.unuse.assert_called_once() def test_batch_next(batch): - batch.event = MagicMock() batch.opts["fun"] = "my.fun" batch.opts["arg"] = [] - batch._get_next = MagicMock(return_value={"foo", "bar"}) batch.batch_size = 2 future = salt.ext.tornado.gen.Future() - future.set_result({"minions": ["foo", "bar"]}) - batch.local.run_job_async.return_value = future - with patch("salt.ext.tornado.gen.sleep", return_value=future): + future.set_result({}) + with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object( + batch, "events_channel", MagicMock() + ), patch.object(batch, "_get_next", return_value={"foo", "bar"}), patch.object( + batch, "find_job", return_value=future + ) as find_job_mock: + batch.events_channel.local_client.run_job_async.return_value = future batch.run_next() - assert batch.local.run_job_async.call_args[0] == ( + assert batch.events_channel.local_client.run_job_async.call_args[0] == ( {"foo", "bar"}, "my.fun", [], "list", ) - assert batch.event.io_loop.spawn_callback.call_args[0] == ( - batch.find_job, - {"foo", 
"bar"}, - ) + assert find_job_mock.call_args[0] == ({"foo", "bar"},) assert batch.active == {"bar", "foo"} @@ -239,124 +262,132 @@ def test_next_batch_all_timedout(batch): def test_batch__event_handler_ping_return(batch): batch.targeted_minions = {"foo"} - batch.event = MagicMock( - unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"})) - ) batch.start() assert batch.minions == set() - batch._BatchAsync__event_handler(MagicMock()) + batch._BatchAsync__event_handler( + "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return" + ) assert batch.minions == {"foo"} assert batch.done_minions == set() def test_batch__event_handler_call_start_batch_when_all_pings_return(batch): batch.targeted_minions = {"foo"} - batch.event = MagicMock( - unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"})) - ) - batch.start() - batch._BatchAsync__event_handler(MagicMock()) - assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,) + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object(batch, "start_batch", return_value=future) as start_batch_mock: + batch.start() + batch._BatchAsync__event_handler( + "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return" + ) + start_batch_mock.assert_called_once() def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(batch): batch.targeted_minions = {"foo", "bar"} - batch.event = MagicMock( - unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"})) - ) - batch.start() - batch._BatchAsync__event_handler(MagicMock()) - assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0 + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object(batch, "start_batch", return_value=future) as start_batch_mock: + batch.start() + batch._BatchAsync__event_handler( + "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return" + ) + start_batch_mock.assert_not_called() def test_batch__event_handler_batch_run_return(batch): - 
batch.event = MagicMock( - unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"})) - ) - batch.start() - batch.active = {"foo"} - batch._BatchAsync__event_handler(MagicMock()) - assert batch.active == set() - assert batch.done_minions == {"foo"} - assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.schedule_next,) + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object( + batch, "schedule_next", return_value=future + ) as schedule_next_mock: + batch.start() + batch.active = {"foo"} + batch._BatchAsync__event_handler( + "salt/job/1235/ret/foo", {"id": "foo"}, "batch_run" + ) + assert batch.active == set() + assert batch.done_minions == {"foo"} + schedule_next_mock.assert_called_once() def test_batch__event_handler_find_job_return(batch): - batch.event = MagicMock( - unpack=MagicMock( - return_value=( - "salt/job/1236/ret/foo", - {"id": "foo", "return": "deadbeaf"}, - ) - ) - ) batch.start() - batch.patterns.add(("salt/job/1236/ret/*", "find_job_return")) - batch._BatchAsync__event_handler(MagicMock()) + batch._BatchAsync__event_handler( + "salt/job/1236/ret/foo", {"id": "foo", "return": "deadbeaf"}, "find_job_return" + ) assert batch.find_job_returned == {"foo"} def test_batch_run_next_end_batch_when_no_next(batch): - batch.end_batch = MagicMock() - batch._get_next = MagicMock(return_value={}) - batch.run_next() - assert len(batch.end_batch.mock_calls) == 1 + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object( + batch, "_get_next", return_value={} + ), patch.object( + batch, "end_batch", return_value=future + ) as end_batch_mock: + batch.run_next() + end_batch_mock.assert_called_once() def test_batch_find_job(batch): - batch.event = MagicMock() future = salt.ext.tornado.gen.Future() future.set_result({}) - batch.local.run_job_async.return_value = future batch.minions = {"foo", "bar"} - batch.jid_gen = MagicMock(return_value="1234") - with patch("salt.ext.tornado.gen.sleep", 
return_value=future): + with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object( + batch, "check_find_job", return_value=future + ) as check_find_job_mock, patch.object( + batch, "jid_gen", return_value="1236" + ): + batch.events_channel.local_client.run_job_async.return_value = future batch.find_job({"foo", "bar"}) - assert batch.event.io_loop.spawn_callback.call_args[0] == ( - batch.check_find_job, + assert check_find_job_mock.call_args[0] == ( {"foo", "bar"}, - "1234", + "1236", ) def test_batch_find_job_with_done_minions(batch): batch.done_minions = {"bar"} - batch.event = MagicMock() future = salt.ext.tornado.gen.Future() future.set_result({}) - batch.local.run_job_async.return_value = future batch.minions = {"foo", "bar"} - batch.jid_gen = MagicMock(return_value="1234") - with patch("salt.ext.tornado.gen.sleep", return_value=future): + with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object( + batch, "check_find_job", return_value=future + ) as check_find_job_mock, patch.object( + batch, "jid_gen", return_value="1236" + ): + batch.events_channel.local_client.run_job_async.return_value = future batch.find_job({"foo", "bar"}) - assert batch.event.io_loop.spawn_callback.call_args[0] == ( - batch.check_find_job, + assert check_find_job_mock.call_args[0] == ( {"foo"}, - "1234", + "1236", ) def test_batch_check_find_job_did_not_return(batch): - batch.event = MagicMock() batch.active = {"foo"} batch.find_job_returned = set() - batch.patterns = {("salt/job/1234/ret/*", "find_job_return")} - batch.check_find_job({"foo"}, jid="1234") - assert batch.find_job_returned == set() - assert batch.active == set() - assert len(batch.event.io_loop.add_callback.mock_calls) == 0 + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object(batch, "find_job", return_value=future) as find_job_mock: + batch.check_find_job({"foo"}, jid="1234") + assert batch.find_job_returned == set() + assert batch.active == set() + 
find_job_mock.assert_not_called() def test_batch_check_find_job_did_return(batch): - batch.event = MagicMock() batch.find_job_returned = {"foo"} - batch.patterns = {("salt/job/1234/ret/*", "find_job_return")} - batch.check_find_job({"foo"}, jid="1234") - assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"}) + future = salt.ext.tornado.gen.Future() + future.set_result({}) + with patch.object(batch, "find_job", return_value=future) as find_job_mock: + batch.check_find_job({"foo"}, jid="1234") + find_job_mock.assert_called_once_with({"foo"}) def test_batch_check_find_job_multiple_states(batch): - batch.event = MagicMock() # currently running minions batch.active = {"foo", "bar"} @@ -372,21 +403,28 @@ def test_batch_check_find_job_multiple_states(batch): # both not yet done but only 'foo' responded to find_job not_done = {"foo", "bar"} - batch.patterns = {("salt/job/1234/ret/*", "find_job_return")} - batch.check_find_job(not_done, jid="1234") + future = salt.ext.tornado.gen.Future() + future.set_result({}) - # assert 'bar' removed from active - assert batch.active == {"foo"} + with patch.object(batch, "schedule_next", return_value=future), patch.object( + batch, "find_job", return_value=future + ) as find_job_mock: + batch.check_find_job(not_done, jid="1234") - # assert 'bar' added to timedout_minions - assert batch.timedout_minions == {"bar", "faz"} + # assert 'bar' removed from active + assert batch.active == {"foo"} - # assert 'find_job' schedueled again only for 'foo' - assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"}) + # assert 'bar' added to timedout_minions + assert batch.timedout_minions == {"bar", "faz"} + + # assert 'find_job' scheduled again only for 'foo' + find_job_mock.assert_called_once_with({"foo"}) def test_only_on_run_next_is_scheduled(batch): - batch.event = MagicMock() + future = salt.ext.tornado.gen.Future() + future.set_result({}) batch.scheduled = True - batch.schedule_next() - 
assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0 + with patch.object(batch, "run_next", return_value=future) as run_next_mock: + batch.schedule_next() + run_next_mock.assert_not_called()