From 731a3ff9ff72552d5397022025bdb1e91e77550a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Oct 2023 18:04:33 +0100 Subject: [PATCH] [Example] Custom task and algorithm Signed-off-by: Matteo Bettini --- examples/extending/algorithm/README.md | 18 ++ .../extending/algorithm/custom_algorithm.py | 254 ++++++++++++++++++ .../algorithm/customiqlalgorithm.yaml | 7 + .../extending/{ => model}/custom_model.py | 0 examples/extending/task/README.md | 16 ++ examples/extending/task/custom_task.py | 90 +++++++ 6 files changed, 385 insertions(+) create mode 100644 examples/extending/algorithm/README.md create mode 100644 examples/extending/algorithm/custom_algorithm.py create mode 100644 examples/extending/algorithm/customiqlalgorithm.yaml rename examples/extending/{ => model}/custom_model.py (100%) create mode 100644 examples/extending/task/README.md create mode 100644 examples/extending/task/custom_task.py diff --git a/examples/extending/algorithm/README.md b/examples/extending/algorithm/README.md new file mode 100644 index 00000000..3cf4323a --- /dev/null +++ b/examples/extending/algorithm/README.md @@ -0,0 +1,18 @@ + +# Creating a new algorithm + +Here are the steps to create a new algorithm. You can find the custom IQL algorithm +created for this example in [`custom_agorithm.py`](custom_algorithm.py). + +1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example +in [`custom_agorithm.py`](custom_algorithm.py). These will be the algorithm code +and a associated dataclass to validate loaded configs. +2. Create a `customalgorithm.yaml` with the configuration parameters you defined +in your script. Make sure it has `customalgorithm_config` within its defaults at +the top of the file to let hydra know which python dataclass it is +associated to. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml) +for an example +3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and +your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to +override from) +4. Add `{"custom_agorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithm.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) diff --git a/examples/extending/algorithm/custom_algorithm.py b/examples/extending/algorithm/custom_algorithm.py new file mode 100644 index 00000000..79ca4bb8 --- /dev/null +++ b/examples/extending/algorithm/custom_algorithm.py @@ -0,0 +1,254 @@ +from dataclasses import dataclass, MISSING +from typing import Dict, Iterable, Optional, Tuple, Type + +from benchmarl.algorithms.common import Algorithm, AlgorithmConfig +from benchmarl.models.common import ModelConfig +from benchmarl.utils import read_yaml_config + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModule, TensorDictSequential +from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.modules import EGreedyModule, QValueModule +from torchrl.objectives import DQNLoss, LossModule, ValueEstimators + + +class CustomIqlAlgorithm(Algorithm): + def __init__( + self, delay_value: bool, loss_function: str, my_custom_arg: int, **kwargs + ): + # In the init function you can define the init parameters you need, just make sure + # to pass the kwargs to the super() class + super().__init__(**kwargs) + + self.delay_value = delay_value + self.loss_function = loss_function + self.my_custom_arg = my_custom_arg + + # In all the class you have access to a lot of extra things like + self.my_custom_method() # Custom methods + _ = self.experiment_config # Experiment config + _ = self.model_config # Policy config + _ = self.critic_model_config # Eventual critic config + _ = self.group_map # The group to agent names map + + # Specs + _ = self.observation_spec + _ = self.action_spec + _ = self.state_spec + _ = self.action_mask_spec + + ############################# + # Overridden abstract methods + ############################# + + def _get_loss( + self, group: str, policy_for_loss: TensorDictModule, continuous: bool + ) -> Tuple[LossModule, bool]: + if continuous: + raise NotImplementedError( + "Custom Iql is not compatible with continuous actions." + ) + else: + # Loss + loss_module = DQNLoss( + policy_for_loss, + delay_value=self.delay_value, + loss_function=self.loss_function, + action_space=self.action_spec[group, "action"], + ) + # Always tell the loss where to finc the data + # You can make sure the data is in the right place in self.process_batch + # This loss for example expects all data to have the multagent dimension so we take care of that in + # self.process_batch + loss_module.set_keys( + reward=(group, "reward"), + action=(group, "action"), + done=(group, "done"), + terminated=(group, "terminated"), + action_value=(group, "action_value"), + value=(group, "chosen_action_value"), + priority=(group, "td_error"), + ) + # Choose your value estimator, see what is available in the ValueEstimators enum + loss_module.make_value_estimator( + ValueEstimators.TD0, gamma=self.experiment_config.gamma + ) + # This loss has target delayed parameters so the second value is True + return loss_module, True + + def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]: + # For each loss name, associate it the parameters you want + # You can optionally modify (aggregate) loss names in self.process_loss_vals() + return {"loss": loss.parameters()} + + def _get_policy_for_loss( + self, group: str, model_config: ModelConfig, continuous: bool + ) -> TensorDictModule: + if continuous: + raise ValueError("This should never happen") + + # The number of agents in the group + n_agents = len(self.group_map[group]) + # The shape of the discrete action + logits_shape = [ + *self.action_spec[group, "action"].shape, + self.action_spec[group, "action"].space.n, + ] + + # This is the spec of the policy input for this group + actor_input_spec = CompositeSpec( + { + group: CompositeSpec( + { + "observation": self.observation_spec[group]["observation"] + .clone() + .to(self.device) + }, + shape=(n_agents,), + ) + } + ) + # This is the spec of the policy output for this group + actor_output_spec = CompositeSpec( + { + group: CompositeSpec( + {"action_value": UnboundedContinuousTensorSpec(shape=logits_shape)}, + shape=(n_agents,), + ) + } + ) + # This is our neural policy + actor_module = model_config.get_model( + input_spec=actor_input_spec, + output_spec=actor_output_spec, + agent_group=group, + input_has_agent_dim=True, # Always true for a policy + n_agents=n_agents, + centralised=False, # Always false for a policy + share_params=self.experiment_config.share_policy_params, + device=self.device, + ) + if self.action_mask_spec is not None: + action_mask_key = (group, "action_mask") + else: + action_mask_key = None + + value_module = QValueModule( + action_value_key=(group, "action_value"), + action_mask_key=action_mask_key, + out_keys=[ + (group, "action"), + (group, "action_value"), + (group, "chosen_action_value"), + ], + spec=self.action_spec[group, "action"], + action_space=None, # We already passed the spec + ) + + # Here we chain the actor and the value module to get our policy + return TensorDictSequential(actor_module, value_module) + + def _get_policy_for_collection( + self, policy_for_loss: TensorDictModule, group: str, continuous: bool + ) -> TensorDictModule: + + if self.action_mask_spec is not None: + action_mask_key = (group, "action_mask") + else: + action_mask_key = None + # Add exploration for collection + greedy = EGreedyModule( + annealing_num_steps=self.experiment_config.get_exploration_anneal_frames( + self.on_policy + ), + action_key=(group, "action"), + spec=self.action_spec[(group, "action")], + action_mask_key=action_mask_key, + eps_init=self.experiment_config.exploration_eps_init, + eps_end=self.experiment_config.exploration_eps_end, + ) + return TensorDictSequential(*policy_for_loss, greedy) + + def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase: + # Here we make sure that all entries have the desired shape, + # thus, in case there are shared dones, terminated, or rewards, we expande them + + keys = list(batch.keys(True, True)) + group_shape = batch.get(group).shape + + nested_done_key = ("next", group, "done") + nested_terminated_key = ("next", group, "terminated") + nested_reward_key = ("next", group, "reward") + + if nested_done_key not in keys: + batch.set( + nested_done_key, + batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)), + ) + if nested_terminated_key not in keys: + batch.set( + nested_terminated_key, + batch.get(("next", "terminated")) + .unsqueeze(-1) + .expand((*group_shape, 1)), + ) + + if nested_reward_key not in keys: + batch.set( + nested_reward_key, + batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)), + ) + + return batch + + ##################### + # Custom new methods + ##################### + def my_custom_method(self): + pass + + +@dataclass +class CustomIqlConfig(AlgorithmConfig): + # This is a class representing the configuration of your algorithm + # It will be used to validate loaded configs, so that everytime you load this algorithm + # we know exactly which and what parameters to expect with their types + + # This is a list of args passed to your algorithm + delay_value: bool = MISSING + loss_function: str = MISSING + my_custom_arg: int = MISSING + + @staticmethod + def associated_class() -> Type[Algorithm]: + # The associated algorithm class + return CustomIqlAlgorithm + + @staticmethod + def supports_continuous_actions() -> bool: + # Is it compatible with continuous actions? + return False + + @staticmethod + def supports_discrete_actions() -> bool: + # Is it compatible with discrete actions? + return True + + @staticmethod + def on_policy() -> bool: + # Should it be trained on or off policy? + return False + + @staticmethod + def get_from_yaml(path: Optional[str] = None): + if path is None: + # If get_from_yaml is called without a path, + # we load from benchmarl/conf/algorithm/{CustomIqlConfig.associated_class().__name__} + return CustomIqlConfig( + **AlgorithmConfig._load_from_yaml( + name=CustomIqlConfig.associated_class().__name__, + ) + ) + else: + # Otherwise, we load it from the given absolute path + return CustomIqlConfig(**read_yaml_config(path)) diff --git a/examples/extending/algorithm/customiqlalgorithm.yaml b/examples/extending/algorithm/customiqlalgorithm.yaml new file mode 100644 index 00000000..a5a2e12b --- /dev/null +++ b/examples/extending/algorithm/customiqlalgorithm.yaml @@ -0,0 +1,7 @@ +defaults: + - customalgorithm_config + - _self_ + +delay_value: bool = True +loss_function: str = "l2" +my_custom_arg: int = 3 diff --git a/examples/extending/custom_model.py b/examples/extending/model/custom_model.py similarity index 100% rename from examples/extending/custom_model.py rename to examples/extending/model/custom_model.py diff --git a/examples/extending/task/README.md b/examples/extending/task/README.md new file mode 100644 index 00000000..e4ecd9e0 --- /dev/null +++ b/examples/extending/task/README.md @@ -0,0 +1,16 @@ + +# Creating a new task (from a new environemnt) + +Here are the steps to create a new algorithm. You can find the custom IQL algorithm +created for this example in [`custom_agorithm.py`](custom_algorithm.py). + +1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py). +This is an enum with task entries and some abstract functions you need to implement. + +2. Create a `customenv` folder with a yaml configuration file for each of your tasks +You can see [`customenv`](customenv) for an example. +3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and +your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to +override from). +4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the +[`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks. diff --git a/examples/extending/task/custom_task.py b/examples/extending/task/custom_task.py new file mode 100644 index 00000000..860dcee3 --- /dev/null +++ b/examples/extending/task/custom_task.py @@ -0,0 +1,90 @@ +from typing import Callable, Dict, List, Optional + +from benchmarl.environments.common import Task +from benchmarl.utils import DEVICE_TYPING + +from tensordict import TensorDictBase +from torchrl.data import CompositeSpec +from torchrl.envs import EnvBase +from torchrl.envs.libs import YourTorchRLEnvConstructor + + +class CustomEnvTask(Task): + # Your task names. + # Their config will be loaded from benchmarl/conf/task/customenv + + TASK_1 = None # Loaded automatically from benchmarl/conf/task/customenv/task_1 + TASK_2 = None # Loaded automatically from benchmarl/conf/task/customenv/task_2 + + def get_env_fun( + self, + num_envs: int, + continuous_actions: bool, + seed: Optional[int], + device: DEVICE_TYPING, + ) -> Callable[[], EnvBase]: + return lambda: YourTorchRLEnvConstructor( + scenario=self.name.lower(), + num_envs=num_envs, # Number of vectorized envs (do not use this param if the env is not vecotrized) + continuous_actions=continuous_actions, # Ignore this param if your env does not have this choice + seed=seed, + device=device, + categorical_actions=True, # If your env has discrete actions, they need to be categorical (TorchRL can help with this) + **self.config, # Pass the loaded config (this is what is in your yaml + ) + + def supports_continuous_actions(self) -> bool: + # Does the environment support continuous actions? + return True + + def supports_discrete_actions(self) -> bool: + # Does the environment support discrete actions? + return True + + def has_render(self, env: EnvBase) -> bool: + # Does the env have a env.render(mode="rgb_array") or env.render() function? + return True + + def max_steps(self, env: EnvBase) -> int: + # Maximum number of steps for a rollout during evaluation + return 100 + + def group_map(self, env: EnvBase) -> Dict[str, List[str]]: + # The group map mapping group names to agent names + # The data in the tensordict will havebe presented this way + return {"agents": [agent.name for agent in env.agents]} + + def observation_spec(self, env: EnvBase) -> CompositeSpec: + # A spec for the observation. + # Must be a CompositeSpec with one (group_name, "observation") entry per group. + return env.full_observation_spec + + def action_spec(self, env: EnvBase) -> CompositeSpec: + # A spec for the observation. + # If provided, must be a CompositeSpec with one (group_name, "action") entry per group. + return env.full_action_spec + + def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]: + # A spec for the state. + # If provided, must be a CompositeSpec with one "state" entry + return None + + def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]: + # A spec for the state. + # If provided, must be a CompositeSpec with one (group_name, "action_mask") entry per group. + return None + + def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: + # A spec for the observation. + # If provided, must be a CompositeSpec with one (group_name, "info") entry per group (this entry can be composite). + return None + + @staticmethod + def env_name() -> str: + # The name of the environment in the benchmarl/conf/task folder + return "customenv" + + def log_info(self, batch: TensorDictBase) -> Dict[str, float]: + # Optionally return a str->float dict with extra things to log + # This function has access to the collected batch and its optional + return {}