Stop async evaluation immediately for Evaluator.close() (#1695)
* Stop async evaluation immediately for Evaluator.close()

When an error happens during training, ALF will try to close all the
evaluators. However, some evaluations take a long time, which causes the
main process to just wait, and we cannot see the error message from the
exception.

With this change, the evaluation stops immediately, so we can see the
error messages.

* PeekableQueue

* remove debug print

* Remove unused code

* Comment
emailweixu authored Sep 4, 2024
1 parent 6d56e66 commit a6e8fd9
Showing 2 changed files with 74 additions and 5 deletions.
76 changes: 71 additions & 5 deletions alf/trainers/evaluator.py
@@ -113,13 +113,21 @@ def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
        self._evaluator.eval(algorithm, step_metric_values)

    def close(self):
        """Stop the ongoing evaluation and close the evaluator."""
        if self._async:
            job = EvalJob(type="stop")
            self._job_queue.put(job)
            self._worker.join()
        else:
            self._env.close()

    def wait_complete(self):
        """Wait until evaluation is complete."""
        if self._async:
            job = EvalJob(type="wait")
            self._job_queue.put(job)
            self._done_queue.get()


def _define_flags():
    flags.DEFINE_string('gin_file', None, 'Path to the gin-config file.')
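
For context, a hedged reconstruction of the EvalJob message that close() and wait_complete() enqueue above; the real definition lives elsewhere in evaluator.py, and the exact fields are assumptions inferred from how the worker accesses them:

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class EvalJob:
        # "eval" runs one evaluation, "wait" makes the worker acknowledge on
        # done_queue (the barrier behind wait_complete()), and "stop" ends
        # the worker loop.
        type: str
        state_dict: Optional[Dict[str, Any]] = None    # weights for "eval" jobs
        step_metrics: Optional[Dict[str, int]] = None  # values for eval summaries
        global_counter: int = 0                        # training iteration count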
@@ -131,6 +139,40 @@ def _define_flags():
FLAGS = flags.FLAGS


class PeekableQueue(object):
    """A queue that supports peeking the first element without removing it.

    Note that this can only be used for a queue with one consumer.
    """

    def __init__(self, queue: mp.Queue):
        self._queue = queue
        self._elements = []  # elements that are peeked but not removed

    def peek(self):
        """Peek the first element in the queue without removing it.

        Returns:
            The first element in the queue. ``None`` if the queue is empty.
        """
        if len(self._elements) == 0:
            if not self._queue.empty():
                self._elements.append(self._queue.get())
        if len(self._elements) > 0:
            return self._elements[0]
        else:
            return None

    def get(self):
        if len(self._elements) == 0:
            return self._queue.get()
        else:
            return self._elements.pop(0)

    def empty(self):
        return len(self._elements) == 0 and self._queue.empty()
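
A short usage sketch of the peek semantics (single process, so producer and consumer are the same; the sleep is needed because mp.Queue delivers items through a background feeder thread):

    import multiprocessing as mp
    import time

    raw = mp.Queue()
    q = PeekableQueue(raw)
    raw.put("stop")              # producer side; normally another process
    time.sleep(0.1)              # let the feeder thread deliver the item

    assert q.peek() == "stop"    # observed but not consumed
    assert q.peek() == "stop"    # still the first element
    assert q.get() == "stop"     # now actually removed
    assert q.peek() is None      # an empty queue peeks as None
    assert q.empty()

Peeking rather than getting is what lets evaluate() notice a pending "stop" job without stealing a queued "eval" job from the worker's dispatch loop.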


class SyncEvaluator(object):
    """Evaluator for performing evaluation on the current algorithm.
@@ -145,7 +187,10 @@ def __init__(self, env, config):
        self._summary_writer = alf.summary.create_summary_writer(
            eval_dir, flush_secs=config.summaries_flush_secs)

-    def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
+    def eval(self,
+             algorithm: RLAlgorithm,
+             step_metric_values: Dict[str, int],
+             job_queue: Optional[PeekableQueue] = None):
        """Do one round of evaluation.

        This function will return after finishing the evaluation.
@@ -158,11 +203,17 @@ def eval(self, algorithm: RLAlgorithm, step_metric_values: Dict[str, int]):
            step_metric_values: a dictionary of step metric values to generate
                the evaluation summaries against. Note that it needs to contain
                "EnvironmentSteps" at least.
            job_queue: This is only used when `eval()` is called from a worker
                process. If during the evaluation, the worker receives a "stop"
                job from the main process, it will stop the evaluation and
                return immediately.
        """
        with alf.summary.push_summary_writer(self._summary_writer):
            logging.info("Start evaluation")
            metrics = evaluate(self._env, algorithm,
-                               self._config.num_eval_episodes)
+                               self._config.num_eval_episodes, job_queue)
            if metrics is None:
                return
            common.log_metrics(metrics)
            for metric in metrics:
                metric.gen_summaries(
@@ -285,6 +336,7 @@ def _worker(job_queue: mp.Queue,
                        config.num_env_steps)
    alf.summary.enable_summary()
    evaluator = SyncEvaluator(env, config)
    job_queue = PeekableQueue(job_queue)
    logging.info("Evaluator started")
    while True:
        job = job_queue.get()
@@ -298,9 +350,11 @@ def _worker(job_queue: mp.Queue,
                         job.global_counter, env_steps)
            algorithm.load_state_dict(job.state_dict)
            done_queue.put(None)
-            evaluator.eval(algorithm, job.step_metrics)
+            evaluator.eval(algorithm, job.step_metrics, job_queue)
        elif job.type == "stop":
            break
        elif job.type == "wait":
            done_queue.put(None)
        else:
            raise KeyError('Received message of unknown type {}'.format(
                job.type))
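
The "wait" branch above is what makes wait_complete() a barrier: jobs are handled strictly in FIFO order, so the worker only acknowledges a "wait" job after every earlier "eval" job has finished. A hedged sketch from the main-process side, assuming the EvalJob reconstruction above and that the caller consumes one done_queue acknowledgment per "eval" job (as the done_queue.put(None) after load_state_dict suggests):

    job_queue.put(EvalJob(type="eval", state_dict=algorithm.state_dict(),
                          step_metrics=step_metrics))
    done_queue.get()    # worker has copied the weights; evaluation now runs
    job_queue.put(EvalJob(type="wait"))
    done_queue.get()    # unblocks only once the evaluation above has finished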
@@ -314,14 +368,21 @@ def _worker(job_queue: mp.Queue,


@common.mark_eval
-def evaluate(env: AlfEnvironment, algorithm: RLAlgorithm,
-             num_episodes: int) -> List[alf.metrics.StepMetric]:
+def evaluate(env: AlfEnvironment,
+             algorithm: RLAlgorithm,
+             num_episodes: int,
+             job_queue: Optional[PeekableQueue] = None
+             ) -> List[alf.metrics.StepMetric]:
    """Perform one round of evaluation.

    Args:
        env: the environment
        algorithm: the training algorithm
        num_episodes: number of episodes to evaluate
        job_queue: This is only used when `eval()` is called from a worker
            process. If during the evaluation, the worker receives a "stop"
            job from the main process, it will stop the evaluation and
            return immediately.

    Returns:
        a list of metrics from the evaluation
    """
@@ -379,6 +440,11 @@ def evaluate(env: AlfEnvironment, algorithm: RLAlgorithm,

policy_state = policy_step.state
time_step = next_time_step
if job_queue is not None:
job = job_queue.peek()
if job is not None and job.type == "stop":
logging.info("Received stop signal. Aborting evaluation.")
return None

env.reset()
return metrics
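
The per-step check above is cooperative cancellation: the evaluation loop polls for a pending "stop" at a safe point instead of being killed from outside, and because it only peeks, a queued "eval" job is left in place for the worker's dispatch loop, which also consumes the "stop" job itself and breaks. A self-contained sketch of the pattern with hypothetical names:

    from typing import List, Optional

    def run_steps(num_steps: int, job_queue=None) -> Optional[List[int]]:
        """Does work step by step; returns None if a "stop" job is pending."""
        results = []
        for step in range(num_steps):
            results.append(step * step)    # stand-in for one environment step
            if job_queue is not None:
                job = job_queue.peek()     # inspect without consuming
                if job is not None and job.type == "stop":
                    return None            # caller treats None as "aborted"
        return results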
3 changes: 3 additions & 0 deletions alf/trainers/policy_trainer.py
@@ -702,6 +702,9 @@ def _train(self):
            self._save_checkpoint()
            self._checkpoint_requested = False

        if self._evaluate:
            self._evaluator.wait_complete()

    def _need_to_evaluate(self, iter_num):
        if not self._evaluate:
            return False
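
For illustration, a minimal sketch of the resulting trainer lifecycle with a hypothetical driver (only eval(), wait_complete(), and close() are methods shown in this commit): on a normal exit the trainer drains the async evaluator, while on an error path close() now returns promptly because the worker aborts any ongoing evaluation.

    def train_and_evaluate(trainer, evaluator, num_iterations):
        try:
            for i in range(num_iterations):
                trainer.train_iter(i)                # hypothetical method
                if trainer.need_to_evaluate(i):      # hypothetical method
                    evaluator.eval(trainer.algorithm, trainer.step_metrics())
            evaluator.wait_complete()   # normal exit: let queued evals finish
        finally:
            evaluator.close()           # error exit: stop ongoing eval promptly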
