From a6e8fd93ce483e405e70497cd6294d3af3f32c3f Mon Sep 17 00:00:00 2001
From: emailweixu
Date: Wed, 4 Sep 2024 14:47:48 -0700
Subject: [PATCH] Stop async evaluation immediately for Evaluator.close()
 (#1695)

* Stop async evaluation immediately for Evaluator.close()

When an error happens during training, ALF tries to close all the
evaluators. However, some evaluations take a long time, which makes the
main process simply wait, and we cannot see the error message from the
exception. With this change, the evaluation stops immediately so we can
see the error messages.

* PeekableQueue

* remove debug print

* Remove unused code

* Comment

---
 alf/trainers/evaluator.py      | 76 +++++++++++++++++++++++++++++++---
 alf/trainers/policy_trainer.py |  3 ++
 2 files changed, 74 insertions(+), 5 deletions(-)

diff --git a/alf/trainers/evaluator.py b/alf/trainers/evaluator.py
index 48bc60beb..dd6367900 100644
--- a/alf/trainers/evaluator.py
+++ b/alf/trainers/evaluator.py
@@ -113,6 +113,7 @@ def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
         self._evaluator.eval(algorithm, step_metric_values)
 
     def close(self):
+        """Stop the ongoing evaluation and close the evaluator."""
         if self._async:
             job = EvalJob(type="stop")
             self._job_queue.put(job)
@@ -120,6 +121,13 @@ def close(self):
         else:
             self._env.close()
 
+    def wait_complete(self):
+        """Wait until evaluation is complete."""
+        if self._async:
+            job = EvalJob(type="wait")
+            self._job_queue.put(job)
+            self._done_queue.get()
+
 
 def _define_flags():
     flags.DEFINE_string('gin_file', None, 'Path to the gin-config file.')
@@ -131,6 +139,40 @@ def _define_flags():
 
 FLAGS = flags.FLAGS
 
 
+class PeekableQueue(object):
+    """A queue that supports peeking the first element without removing it.
+
+    Note that this can only be used for a queue with one consumer.
+    """
+
+    def __init__(self, queue: mp.Queue):
+        self._queue = queue
+        self._elements = []  # elements that are peeked but not removed
+
+    def peek(self):
+        """Peek the first element in the queue without removing it.
+
+        Returns:
+            The first element in the queue. ``None`` if the queue is empty.
+        """
+        if len(self._elements) == 0:
+            if not self._queue.empty():
+                self._elements.append(self._queue.get())
+        if len(self._elements) > 0:
+            return self._elements[0]
+        else:
+            return None
+
+    def get(self):
+        if len(self._elements) == 0:
+            return self._queue.get()
+        else:
+            return self._elements.pop(0)
+
+    def empty(self):
+        return len(self._elements) == 0 and self._queue.empty()
+
+
 class SyncEvaluator(object):
     """Evaluator for performing evaluation on the current algorithm.
@@ -145,7 +187,10 @@ def __init__(self, env, config):
         self._summary_writer = alf.summary.create_summary_writer(
             eval_dir, flush_secs=config.summaries_flush_secs)
 
-    def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
+    def eval(self,
+             algorithm: RLAlgorithm,
+             step_metric_values: Dict[str, int],
+             job_queue: Optional[PeekableQueue] = None):
         """Do one round of evaluation.
 
         This function will return after finishing the evaluation.
@@ -158,11 +203,17 @@
             step_metric_values: a dictionary of step metric values to generate
                 the evaluation summaries against. Note that it needs to contain
                 "EnvironmentSteps" at least.
+            job_queue: This is only used when `eval()` is called from a worker
+                process. If during the evaluation, the worker receives a "stop"
+                job from the main process, it will stop the evaluation and
+                return immediately.
""" with alf.summary.push_summary_writer(self._summary_writer): logging.info("Start evaluation") metrics = evaluate(self._env, algorithm, - self._config.num_eval_episodes) + self._config.num_eval_episodes, job_queue) + if metrics is None: + return common.log_metrics(metrics) for metric in metrics: metric.gen_summaries( @@ -285,6 +336,7 @@ def _worker(job_queue: mp.Queue, config.num_env_steps) alf.summary.enable_summary() evaluator = SyncEvaluator(env, config) + job_queue = PeekableQueue(job_queue) logging.info("Evaluator started") while True: job = job_queue.get() @@ -298,9 +350,11 @@ def _worker(job_queue: mp.Queue, job.global_counter, env_steps) algorithm.load_state_dict(job.state_dict) done_queue.put(None) - evaluator.eval(algorithm, job.step_metrics) + evaluator.eval(algorithm, job.step_metrics, job_queue) elif job.type == "stop": break + elif job.type == "wait": + done_queue.put(None) else: raise KeyError('Received message of unknown type {}'.format( job.type)) @@ -314,14 +368,21 @@ def _worker(job_queue: mp.Queue, @common.mark_eval -def evaluate(env: AlfEnvironment, algorithm: RLAlgorithm, - num_episodes: int) -> List[alf.metrics.StepMetric]: +def evaluate(env: AlfEnvironment, + algorithm: RLAlgorithm, + num_episodes: int, + job_queue: Optional[PeekableQueue] = None + ) -> List[alf.metrics.StepMetric]: """Perform one round of evaluation. Args: env: the environment algorithm: the training algorithm num_episodes: number of episodes to evaluate + job_queue: This is only used when `eval()` is called from a worker + process. If during the evaluation, the worker receives a "stop" + job from the main process, it will stop the evaluation and + return immediately. Returns: a list of metrics from the evaluation """ @@ -379,6 +440,11 @@ def evaluate(env: AlfEnvironment, algorithm: RLAlgorithm, policy_state = policy_step.state time_step = next_time_step + if job_queue is not None: + job = job_queue.peek() + if job is not None and job.type == "stop": + logging.info("Received stop signal. Aborting evaluation.") + return None env.reset() return metrics diff --git a/alf/trainers/policy_trainer.py b/alf/trainers/policy_trainer.py index 5ccb3be5b..72fd6d1df 100644 --- a/alf/trainers/policy_trainer.py +++ b/alf/trainers/policy_trainer.py @@ -702,6 +702,9 @@ def _train(self): self._save_checkpoint() self._checkpoint_requested = False + if self._evaluate: + self._evaluator.wait_complete() + def _need_to_evaluate(self, iter_num): if not self._evaluate: return False