[Callback] Introduce custom callbacks
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 6, 2023
1 parent b32c0a2 commit f4a1b5f
Showing 4 changed files with 112 additions and 4 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -398,4 +398,9 @@ python benchmarl/run.py task=vmas/balance algorithm=mappo experiment.max_n_iters
[![Example](https://img.shields.io/badge/Example-blue.svg)](examples/checkpointing/reload_experiment.py)

### Callbacks
TBC

Experiments optionally take a list of [`Callback`](benchmarl/experiment/callback.py) objects, which expose several hooks you can implement to see what is going on during training, such as `on_batch_collected`, `on_train_end`, and `on_evaluation_end`.

[![Example](https://img.shields.io/badge/Example-blue.svg)](examples/callback/custom_callback.py)
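
As a minimal sketch (the `MyLoggingCallback` name is just an illustrative placeholder for a user-defined class), a custom callback subclasses `Callback` and overrides only the hooks it needs:

from typing import List

from benchmarl.experiment.callback import Callback
from tensordict import TensorDictBase


class MyLoggingCallback(Callback):
    def on_batch_collected(self, batch: TensorDictBase):
        # Runs after every collection iteration with the sampled batch.
        print(f"Collected a batch with {batch.numel()} frames")

    def on_evaluation_end(self, rollouts: List[TensorDictBase]):
        # self.experiment is populated when the callback is registered by the experiment.
        print(f"Finished evaluating {len(rollouts)} rollouts")

The callback is then passed to the `Experiment` constructor through its `callbacks` argument, exactly as in the linked example and the new `examples/callback/custom_callback.py` file below.
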
38 changes: 38 additions & 0 deletions benchmarl/experiment/callback.py
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import List

from tensordict import TensorDictBase


class Callback:
def __init__(self):
self.experiment = None

def on_batch_collected(self, batch: TensorDictBase):
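        # Called by the experiment after each batch of frames is collected.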
pass

def on_train_end(self, training_td: TensorDictBase):
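        # Called after training on a group, with the stacked training tensordicts for that group.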
pass

def on_evaluation_end(self, rollouts: List[TensorDictBase]):
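        # Called after an evaluation run, with the list of evaluation rollouts.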
pass


class CallbackNotifier:
def __init__(self, experiment, callbacks: List[Callback]):
self.callbacks = callbacks
for callback in self.callbacks:
callback.experiment = experiment

def on_batch_collected(self, batch: TensorDictBase):
for callback in self.callbacks:
callback.on_batch_collected(batch)

def on_train_end(self, training_td: TensorDictBase):
for callback in self.callbacks:
callback.on_train_end(training_td)

def on_evaluation_end(self, rollouts: List[TensorDictBase]):
for callback in self.callbacks:
callback.on_evaluation_end(rollouts)
21 changes: 18 additions & 3 deletions benchmarl/experiment/experiment.py
@@ -21,6 +21,8 @@

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task

from benchmarl.experiment.callback import Callback, CallbackNotifier
from benchmarl.experiment.logger import MultiAgentLogger
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config
@@ -160,7 +162,7 @@ def get_from_yaml(path: Optional[str] = None):
return ExperimentConfig(**read_yaml_config(path))


class Experiment:
class Experiment(CallbackNotifier):
def __init__(
self,
task: Task,
@@ -169,7 +171,12 @@ def __init__(
seed: int,
config: ExperimentConfig,
critic_model_config: Optional[ModelConfig] = None,
callbacks: Optional[List[Callback]] = None,
):
super().__init__(
experiment=self, callbacks=callbacks if callbacks is not None else []
)

self.config = config

self.task = task
@@ -407,6 +414,9 @@ def _collection_loop(self):
pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
pbar.update()

# Callback
self.on_batch_collected(batch)

# Loop over groups
training_start = time.time()
for group in self.group_map.keys():
@@ -422,11 +432,14 @@
// self.config.train_minibatch_size(self.on_policy)
):
training_tds.append(self._optimizer_loop(group))

training_td = torch.stack(training_tds)
self.logger.log_training(
group, torch.stack(training_tds), step=self.n_iters_performed
group, training_td, step=self.n_iters_performed
)

# Callback
self.on_train_end(training_td)

# Exploration update
if isinstance(self.group_policies[group], TensorDictSequential):
explore_layer = self.group_policies[group][-1]
@@ -576,6 +589,8 @@ def callback(env, td):
step=iter,
total_frames=self.total_frames,
)
# Callback
self.on_evaluation_end(rollouts)

# Saving trainer state
def state_dict(self) -> OrderedDict:
50 changes: 50 additions & 0 deletions examples/callback/custom_callback.py
@@ -0,0 +1,50 @@
from typing import List

from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.experiment.callback import Callback
from benchmarl.models.mlp import MlpConfig
from tensordict import TensorDictBase


class MyCallbackA(Callback):
def on_batch_collected(self, batch: TensorDictBase):
print(f"Callback A is doing something with the sampling batch {batch}")

def on_train_end(self, training_td: TensorDictBase):
print(
f"Callback A is doing something with the training tensordict {training_td}"
)

def on_evaluation_end(self, rollouts: List[TensorDictBase]):
print(f"Callback A is doing something with the evaluation rollouts {rollouts}")


class MyCallbackB(Callback):
def on_evaluation_end(self, rollouts: List[TensorDictBase]):
print(
"Callback B just reminds you that you fo not need to implement all methods and"
f"you always have access to the experiment {self.experiment} and all its contents"
f"like the policy {self.experiment.policy}"
)


if __name__ == "__main__":

experiment_config = ExperimentConfig.get_from_yaml()
task = VmasTask.BALANCE.get_from_yaml()
algorithm_config = MappoConfig.get_from_yaml()
model_config = MlpConfig.get_from_yaml()
critic_model_config = MlpConfig.get_from_yaml()

experiment = Experiment(
task=task,
algorithm_config=algorithm_config,
model_config=model_config,
critic_model_config=critic_model_config,
seed=0,
config=experiment_config,
callbacks=[MyCallbackA(), MyCallbackB()],
)
experiment.run()
