Skip to content

Commit

Permalink
Quality of Life Improvements (#94)
Browse files Browse the repository at this point in the history
* Better error messages if unknown topic

* Return a meaningful message if method undefined
  • Loading branch information
WardLT authored Mar 3, 2023
1 parent bd334e0 commit 95f3222
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 10 deletions.
3 changes: 3 additions & 0 deletions colmena/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ def send_inputs(self,
"""
self._check_role(QueueRole.CLIENT, 'send_inputs')

# Make sure the queue topic exists
assert topic in self.topics, f'Unknown topic: {topic}. Known are: {", ".join(self.topics)}'

# Make fake kwargs, if needed
if input_kwargs is None:
input_kwargs = dict()
Expand Down
20 changes: 15 additions & 5 deletions colmena/task_server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from inspect import signature
from multiprocessing import Process
from time import perf_counter
from typing import Optional, Callable
from typing import Optional, Callable, Collection

from colmena.exceptions import KillSignalException, TimeoutException
from colmena.models import Result, FailureInformation
Expand Down Expand Up @@ -37,15 +37,16 @@ class BaseTaskServer(Process, metaclass=ABCMeta):
Implementations should also provide a `_cleanup` function that releases any resources reserved by the task server.
"""

def __init__(self, queues: ColmenaQueues, timeout: Optional[int] = None):
def __init__(self, queues: ColmenaQueues, method_names: Collection[str], timeout: Optional[int] = None):
"""
Args:
queues (TaskServerQueues): Queues for the task server
timeout (int): Timeout, if desired
timeout (int): Timeout for reading from the task queue, if desired
"""
super().__init__()
self.queues = queues
self.timeout = timeout
self.method_names = set(method_names)

@abstractmethod
def process_queue(self, topic: str, task: Result):
Expand All @@ -65,8 +66,17 @@ def listen_and_launch(self):
topic, task = self.queues.get_task(self.timeout)
logger.info(f'Received request for {task.method} with topic {topic}')

# Provide it to the workflow system to be executed
self.process_queue(topic, task)
# Make sure the method name is valid
if task.method in self.method_names:
# Provide it to the workflow system to be executed
self.process_queue(topic, task)
else:
task.success = False
task.failure_info = FailureInformation.from_exception(
ValueError(f'Method name "{task.method}" not recognized. Options: {", ".join(self.method_names)}')
)
self.queues.send_result(task, topic)

except KillSignalException:
logger.info('Kill signal received')
return
Expand Down
5 changes: 3 additions & 2 deletions colmena/task_server/funcx.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def __init__(self, methods: Dict[Callable, str],
timeout: Timeout for requests from the task queue
batch_size: Maximum number of task request to receive before submitting
"""
super(FuncXTaskServer, self).__init__(queues, timeout)

# Store the client that has already been authenticated.
self.fx_client = funcx_client
self.fx_exec: FuncXExecutor = None
Expand All @@ -70,6 +68,9 @@ def __init__(self, methods: Dict[Callable, str],
batch_size=batch_size,
)

# Initialize the outputs
super().__init__(queues, self.registered_funcs.keys(), timeout)

def perform_callback(self, future: Future, result: Result, topic: str):
# Check if the failure was due to a ManagerLost
# TODO (wardlt): Remove when we have retry support in FuncX
Expand Down
7 changes: 4 additions & 3 deletions colmena/task_server/parsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,8 @@ def __init__(self, methods: List[Union[Callable, Tuple[Callable, Dict]]],
timeout (int): Timeout, if desired
default_executors: Executor or list of executors to use by default.
"""
super().__init__(queues, timeout)

# Insert _output_workers to the thread count
executors = config.executors.copy()
executors = list(config.executors)
config.executors = executors

# Get a list of default executors that _does not_ include the output workers
Expand Down Expand Up @@ -394,6 +392,9 @@ def __init__(self, methods: List[Union[Callable, Tuple[Callable, Dict]]],
if self.default_method_ is not None:
logger.info(f'There is only one method, so we are using {self.default_method_} as a default')

# Initialize the base class
super().__init__(queues, self.methods_.keys(), timeout)

def _submit(self, task: Result, topic: str) -> Optional[Future]:
# Determine which method to run
if self.default_method_ and task.method is None:
Expand Down
15 changes: 15 additions & 0 deletions colmena/task_server/tests/test_parsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,18 @@ def test_proxy(server_and_queue, store):
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 3


@mark.timeout(10)
def test_bad_method_name(server_and_queue):
"""Make sure tasks with undefined methods are returned with a meaningful error"""

# Start the server
server, queue = server_and_queue
server.start()

# Make sure it sends back a result
queue.send_inputs(1, method='not_a_real_method')
result = queue.get_result()
assert not result.success
assert 'not_a_real_method' in str(result.failure_info.exception)

0 comments on commit 95f3222

Please sign in to comment.