From f4a1b5f3850db13fee6acbfeaf0442d261a04b04 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Oct 2023 16:38:11 +0100 Subject: [PATCH] [Callback] Introduce custom callbacks Signed-off-by: Matteo Bettini --- README.md | 7 +++- benchmarl/experiment/callback.py | 38 +++++++++++++++++++++ benchmarl/experiment/experiment.py | 21 ++++++++++-- examples/callback/custom_callback.py | 50 ++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 benchmarl/experiment/callback.py create mode 100644 examples/callback/custom_callback.py diff --git a/README.md b/README.md index a297c9b8..3598d510 100644 --- a/README.md +++ b/README.md @@ -398,4 +398,9 @@ python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters [![Example](https://img.shields.io/badge/Example-blue.svg)](examples/checkpointing/reload_experiment.py) ### Callbacks -TBC + +Experiments optionally take a list of [`Callback`](benchmarl/experiment/callback.py) which have several methods +that you can implement to see what's going on during training such +as `on_batch_collected`, `on_train_end`, and `on_evaluation_end`. 
+ +[![Example](https://img.shields.io/badge/Example-blue.svg)](examples/callback/custom_callback.py) diff --git a/benchmarl/experiment/callback.py b/benchmarl/experiment/callback.py new file mode 100644 index 00000000..a2bdba1c --- /dev/null +++ b/benchmarl/experiment/callback.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import List + +from tensordict import TensorDictBase + + +class Callback: + def __init__(self): + self.experiment = None + + def on_batch_collected(self, batch: TensorDictBase): + pass + + def on_train_end(self, training_td: TensorDictBase): + pass + + def on_evaluation_end(self, rollouts: List[TensorDictBase]): + pass + + +class CallbackNotifier: + def __init__(self, experiment, callbacks: List[Callback]): + self.callbacks = callbacks + for callback in self.callbacks: + callback.experiment = experiment + + def on_batch_collected(self, batch: TensorDictBase): + for callback in self.callbacks: + callback.on_batch_collected(batch) + + def on_train_end(self, training_td: TensorDictBase): + for callback in self.callbacks: + callback.on_train_end(training_td) + + def on_evaluation_end(self, rollouts: List[TensorDictBase]): + for callback in self.callbacks: + callback.on_evaluation_end(rollouts) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 2d235982..4491af72 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -21,6 +21,8 @@ from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task + +from benchmarl.experiment.callback import Callback, CallbackNotifier from benchmarl.experiment.logger import MultiAgentLogger from benchmarl.models.common import ModelConfig from benchmarl.utils import read_yaml_config @@ -160,7 +162,7 @@ def get_from_yaml(path: Optional[str] = None): return ExperimentConfig(**read_yaml_config(path)) -class Experiment: +class Experiment(CallbackNotifier): def __init__( self, task: Task, @@ 
-169,7 +171,12 @@ def __init__( seed: int, config: ExperimentConfig, critic_model_config: Optional[ModelConfig] = None, + callbacks: Optional[List[Callback]] = None, ): + super().__init__( + experiment=self, callbacks=callbacks if callbacks is not None else [] + ) + self.config = config self.task = task @@ -407,6 +414,9 @@ def _collection_loop(self): pbar.set_description(f"mean return = {self.mean_return}", refresh=False) pbar.update() + # Callback + self.on_batch_collected(batch) + # Loop over groups training_start = time.time() for group in self.group_map.keys(): @@ -422,11 +432,14 @@ def _collection_loop(self): // self.config.train_minibatch_size(self.on_policy) ): training_tds.append(self._optimizer_loop(group)) - + training_td = torch.stack(training_tds) self.logger.log_training( - group, torch.stack(training_tds), step=self.n_iters_performed + group, training_td, step=self.n_iters_performed ) + # Callback + self.on_train_end(training_td) + # Exploration update if isinstance(self.group_policies[group], TensorDictSequential): explore_layer = self.group_policies[group][-1] @@ -576,6 +589,8 @@ def callback(env, td): step=iter, total_frames=self.total_frames, ) + # Callback + self.on_evaluation_end(rollouts) # Saving trainer state def state_dict(self) -> OrderedDict: diff --git a/examples/callback/custom_callback.py b/examples/callback/custom_callback.py new file mode 100644 index 00000000..2c7cc76e --- /dev/null +++ b/examples/callback/custom_callback.py @@ -0,0 +1,50 @@ +from typing import List + +from benchmarl.algorithms import MappoConfig +from benchmarl.environments import VmasTask +from benchmarl.experiment import Experiment, ExperimentConfig +from benchmarl.experiment.callback import Callback +from benchmarl.models.mlp import MlpConfig +from tensordict import TensorDictBase + + +class MyCallbackA(Callback): + def on_batch_collected(self, batch: TensorDictBase): + print(f"Callback A is doing something with the sampling batch {batch}") + + def 
on_train_end(self, training_td: TensorDictBase):
+        print(
+            f"Callback A is doing something with the training tensordict {training_td}"
+        )
+
+    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
+        print(f"Callback A is doing something with the evaluation rollouts {rollouts}")
+
+
+class MyCallbackB(Callback):
+    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
+        print(
+            "Callback B just reminds you that you do not need to implement all methods and "
+            f"you always have access to the experiment {self.experiment} and all its contents "
+            f"like the policy {self.experiment.policy}"
+        )
+
+
+if __name__ == "__main__":
+
+    experiment_config = ExperimentConfig.get_from_yaml()
+    task = VmasTask.BALANCE.get_from_yaml()
+    algorithm_config = MappoConfig.get_from_yaml()
+    model_config = MlpConfig.get_from_yaml()
+    critic_model_config = MlpConfig.get_from_yaml()
+
+    experiment = Experiment(
+        task=task,
+        algorithm_config=algorithm_config,
+        model_config=model_config,
+        critic_model_config=critic_model_config,
+        seed=0,
+        config=experiment_config,
+        callbacks=[MyCallbackA(), MyCallbackB()],
+    )
+    experiment.run()