diff --git a/colmena/queue/base.py b/colmena/queue/base.py index 178f405..5535522 100644 --- a/colmena/queue/base.py +++ b/colmena/queue/base.py @@ -200,6 +200,9 @@ def send_inputs(self, """ self._check_role(QueueRole.CLIENT, 'send_inputs') + # Make sure the queue topic exists + assert topic in self.topics, f'Unknown topic: {topic}. Known are: {", ".join(self.topics)}' + # Make fake kwargs, if needed if input_kwargs is None: input_kwargs = dict() diff --git a/colmena/task_server/base.py b/colmena/task_server/base.py index 77dbd1f..414f999 100644 --- a/colmena/task_server/base.py +++ b/colmena/task_server/base.py @@ -7,7 +7,7 @@ from inspect import signature from multiprocessing import Process from time import perf_counter -from typing import Optional, Callable +from typing import Optional, Callable, Collection from colmena.exceptions import KillSignalException, TimeoutException from colmena.models import Result, FailureInformation @@ -37,15 +37,16 @@ class BaseTaskServer(Process, metaclass=ABCMeta): Implementations should also provide a `_cleanup` function that releases any resources reserved by the task server. """ - def __init__(self, queues: ColmenaQueues, timeout: Optional[int] = None): + def __init__(self, queues: ColmenaQueues, method_names: Collection[str], timeout: Optional[int] = None): """ Args: queues (TaskServerQueues): Queues for the task server - timeout (int): Timeout, if desired + timeout (int): Timeout for reading from the task queue, if desired """ super().__init__() self.queues = queues self.timeout = timeout + self.method_names = set(method_names) @abstractmethod def process_queue(self, topic: str, task: Result): @@ -65,8 +66,17 @@ def listen_and_launch(self): topic, task = self.queues.get_task(self.timeout) logger.info(f'Received request for {task.method} with topic {topic}') - # Provide it to the workflow system to be executed - self.process_queue(topic, task) + # Make sure the method name is valid + if task.method in self.method_names: + # Provide it to the workflow system to be executed + self.process_queue(topic, task) + else: + task.success = False + task.failure_info = FailureInformation.from_exception( + ValueError(f'Method name "{task.method}" not recognized. Options: {", ".join(self.method_names)}') + ) + self.queues.send_result(task, topic) + except KillSignalException: logger.info('Kill signal received') return diff --git a/colmena/task_server/funcx.py b/colmena/task_server/funcx.py index 665f32a..dae2600 100644 --- a/colmena/task_server/funcx.py +++ b/colmena/task_server/funcx.py @@ -49,8 +49,6 @@ def __init__(self, methods: Dict[Callable, str], timeout: Timeout for requests from the task queue batch_size: Maximum number of task request to receive before submitting """ - super(FuncXTaskServer, self).__init__(queues, timeout) - # Store the client that has already been authenticated. self.fx_client = funcx_client self.fx_exec: FuncXExecutor = None @@ -70,6 +68,9 @@ def __init__(self, methods: Dict[Callable, str], batch_size=batch_size, ) + # Initialize the outputs + super().__init__(queues, self.registered_funcs.keys(), timeout) + def perform_callback(self, future: Future, result: Result, topic: str): # Check if the failure was due to a ManagerLost # TODO (wardlt): Remove when we have retry support in FuncX diff --git a/colmena/task_server/parsl.py b/colmena/task_server/parsl.py index 7a09831..e9c1c63 100644 --- a/colmena/task_server/parsl.py +++ b/colmena/task_server/parsl.py @@ -330,10 +330,8 @@ def __init__(self, methods: List[Union[Callable, Tuple[Callable, Dict]]], timeout (int): Timeout, if desired default_executors: Executor or list of executors to use by default. """ - super().__init__(queues, timeout) - # Insert _output_workers to the thread count - executors = config.executors.copy() + executors = list(config.executors) config.executors = executors # Get a list of default executors that _does not_ include the output workers @@ -394,6 +392,9 @@ def __init__(self, methods: List[Union[Callable, Tuple[Callable, Dict]]], if self.default_method_ is not None: logger.info(f'There is only one method, so we are using {self.default_method_} as a default') + # Initialize the base class + super().__init__(queues, self.methods_.keys(), timeout) + def _submit(self, task: Result, topic: str) -> Optional[Future]: # Determine which method to run if self.default_method_ and task.method is None: diff --git a/colmena/task_server/tests/test_parsl.py b/colmena/task_server/tests/test_parsl.py index 6939335..876887a 100644 --- a/colmena/task_server/tests/test_parsl.py +++ b/colmena/task_server/tests/test_parsl.py @@ -214,3 +214,18 @@ def test_proxy(server_and_queue, store): result = queue.get_result() assert result.success, result.failure_info.exception assert len(result.proxy_timing) == 3 + + +@mark.timeout(10) +def test_bad_method_name(server_and_queue): + """Make sure tasks with undefined methods are returned with a meaningful error""" + + # Start the server + server, queue = server_and_queue + server.start() + + # Make sure it sends back a result + queue.send_inputs(1, method='not_a_real_method') + result = queue.get_result() + assert not result.success + assert 'not_a_real_method' in str(result.failure_info.exception)