diff --git a/examples/extending/algorithm/README.md b/examples/extending/algorithm/README.md
index 3cf4323a..64b99df6 100644
--- a/examples/extending/algorithm/README.md
+++ b/examples/extending/algorithm/README.md
@@ -6,13 +6,17 @@ created for this example in [`custom_algorithm.py`](custom_algorithm.py).
 
 1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example
 in [`custom_algorithm.py`](custom_algorithm.py). These will be the algorithm code
-and a associated dataclass to validate loaded configs.
+and an associated dataclass to validate loaded configs.
 2. Create a `customalgorithm.yaml` with the configuration parameters you defined
 in your script. Make sure it has `customalgorithm_config` within its defaults at the top of the file
 to let hydra know which python dataclass it is
 associated to. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml)
-for an example
+for an example.
 3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and
 your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to
 override from)
 4. Add `{"customalgorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py)
+5. Load it with
+```bash
+python benchmarl/run.py algorithm=customalgorithm task=...
+```
diff --git a/examples/extending/algorithm/custom_algorithm.py b/examples/extending/algorithm/custom_algorithm.py
index 79ca4bb8..baaeb0de 100644
--- a/examples/extending/algorithm/custom_algorithm.py
+++ b/examples/extending/algorithm/custom_algorithm.py
@@ -201,6 +201,13 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
 
         return batch
 
+    def process_loss_vals(
+        self, group: str, loss_vals: TensorDictBase
+    ) -> TensorDictBase:
+        # Here you can modify the loss_vals tensordict, which maps loss_name -> loss_value.
+        # For example, you can sum two entries into a new entry to optimize them together.
+        return loss_vals
+
     #####################
     # Custom new methods
     #####################
diff --git a/examples/extending/algorithm/customiqlalgorithm.yaml b/examples/extending/algorithm/customiqlalgorithm.yaml
index a5a2e12b..e8f3b864 100644
--- a/examples/extending/algorithm/customiqlalgorithm.yaml
+++ b/examples/extending/algorithm/customiqlalgorithm.yaml
@@ -2,6 +2,6 @@ defaults:
   - customalgorithm_config
   - _self_
 
-delay_value: bool = True
-loss_function: str = "l2"
-my_custom_arg: int = 3
+delay_value: True
+loss_function: "l2"
+my_custom_arg: 3
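The `process_loss_vals` hook added above returns `loss_vals` unchanged; a minimal sketch of an override that actually merges two losses follows. The keys `loss_objective` and `loss_entropy` are hypothetical placeholders, not entries every algorithm produces:

```python
from tensordict import TensorDictBase

from custom_algorithm import CustomAlgorithm  # assuming the example file is importable


class MyLossMergingAlgorithm(CustomAlgorithm):
    def process_loss_vals(
        self, group: str, loss_vals: TensorDictBase
    ) -> TensorDictBase:
        # Hypothetical keys: sum two losses into a new entry so that
        # they are optimized together. "loss_objective" and "loss_entropy"
        # are placeholders for whatever entries your loss module emits.
        loss_vals["loss_combined"] = (
            loss_vals["loss_objective"] + loss_vals["loss_entropy"]
        )
        return loss_vals
```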
diff --git a/examples/extending/model/README.md b/examples/extending/model/README.md
new file mode 100644
index 00000000..c087aa29
--- /dev/null
+++ b/examples/extending/model/README.md
@@ -0,0 +1,20 @@
+
+# Creating a new model
+
+Here are the steps to create a new model.
+
+1. Create your `CustomModel` and `CustomModelConfig` following the example
+in [`custom_model.py`](custom_model.py). These will be the model code
+and an associated dataclass to validate loaded configs.
+2. Create a `custommodel.yaml` with the configuration parameters you defined
+in your script. Make sure it has a `name` entry equal to `custom_model` to let hydra know which python dataclass it is
+associated to. You can see [`custommodel.yaml`](custommodel.yaml)
+for an example.
+3. Place your model script in [`benchmarl/models`](../../../benchmarl/models) and
+your config in [`benchmarl/conf/model/layers`](../../../benchmarl/conf/model/layers) (or any other place you want to
+override from)
+4. Add `{"custom_model": CustomModelConfig}` to the [`benchmarl.models.model_config_registry`](../../../benchmarl/models/__init__.py)
+5. Load it with
+```bash
+python benchmarl/run.py model=layers/custommodel algorithm=... task=...
+```
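Besides the hydra route in step 5, the config dataclass can also be built programmatically. A rough sketch, assuming `custom_model.py` is importable from the working directory; the yaml path is a placeholder:

```python
from torch import nn

from custom_model import CustomModelConfig  # assuming the example file is importable

# Construct the validated config directly, mirroring custommodel.yaml.
model_config = CustomModelConfig(
    custom_param=3,
    activation_function=nn.Tanh,
)

# Or load it from a yaml file at an explicit (placeholder) absolute path.
model_config = CustomModelConfig.get_from_yaml("/abs/path/to/custommodel.yaml")
```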
diff --git a/examples/extending/model/custom_model.py b/examples/extending/model/custom_model.py
index e69de29b..f4cc3172 100644
--- a/examples/extending/model/custom_model.py
+++ b/examples/extending/model/custom_model.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, MISSING
+from typing import Optional, Type
+
+import torch
+
+from benchmarl.models.common import Model, ModelConfig, parse_model_config
+from benchmarl.utils import read_yaml_config
+from tensordict import TensorDictBase
+from torch import nn
+from torchrl.modules import MLP, MultiAgentMLP
+
+
+class CustomModel(Model):
+    def __init__(
+        self,
+        custom_param: int,
+        activation_function: Type[nn.Module],
+        **kwargs,
+    ):
+        # Models in BenchMARL are instantiated per agent group.
+        # This means that each model will process the inputs for a whole group of agents.
+        # There are some core attributes that models are created with,
+        # which we are now going to illustrate.
+
+        # Remember to pass the kwargs to the super() class.
+        super().__init__(**kwargs)
+
+        # You can create your custom attributes
+        self.custom_param = custom_param
+        self.activation_function = activation_function
+
+        # And access some of the ones already available to your module
+        _ = self.input_spec  # Like its input_spec
+        _ = self.output_spec  # or output_spec
+        _ = self.agent_group  # the name of the agent group the model is for
+        _ = self.n_agents  # or the number of agents this module is for
+
+        # The following are some of the most important attributes of the model.
+        # They decide how the model should be used.
+        # Since models can be used for actors and critics, the model will behave differently
+        # depending on how these attributes are set.
+        # BenchMARL will take care of setting these attributes for you, but your role when making
+        # a custom model is making sure that all cases are handled properly.
+
+        # This tells the model if the input will have a multi-agent dimension or not.
+        # For example, the input of policies will always have this set to true,
+        # but critics that use a global state have this set to false, as the state
+        # is shared by all agents.
+        _ = self.input_has_agent_dim
+
+        # This tells the model if it should have only one set of parameters
+        # or a different set of parameters for each agent.
+        # This is independent of the other options, as it is possible to have different parameters
+        # for centralised critics with global input.
+        _ = self.share_params
+
+        # This tells the model if it has full observability.
+        # This will always be true when self.input_has_agent_dim==False,
+        # but in cases where the input has the agent dimension, this parameter is
+        # used to distinguish between a decentralised model (where each agent's data
+        # is processed separately) and a centralised model (where the model pools all data together).
+        _ = self.centralised
+
+        # This is a dynamically computed attribute that indicates if the output will have the agent dimension.
+        # It will be false when share_params==True and centralised==True, and true in all other cases.
+        # When output_has_agent_dim is true, your model's output should contain the multi-agent dimension,
+        # and the dimension should be absent otherwise.
+        _ = self.output_has_agent_dim
+
+        self.input_features = self.input_leaf_spec.shape[-1]
+        self.output_features = self.output_leaf_spec.shape[-1]
+
+        if self.input_has_agent_dim and not self.centralised:
+            # Instantiate a model for this scenario.
+            # This code will be executed for a policy or a decentralised critic, for example.
+            self.mlp = MultiAgentMLP(
+                n_agent_inputs=self.input_features,
+                n_agent_outputs=self.output_features,
+                n_agents=self.n_agents,
+                centralised=self.centralised,
+                share_params=self.share_params,
+                device=self.device,
+                activation_function=self.activation_function,
+            )
+        elif self.input_has_agent_dim and self.centralised:
+            # Instantiate a model for this scenario.
+            # This code will be executed for a centralised critic that takes local inputs.
+            self.mlp = MultiAgentMLP(
+                n_agent_inputs=self.input_features,
+                n_agent_outputs=self.output_features,
+                n_agents=self.n_agents,
+                centralised=self.centralised,
+                share_params=self.share_params,
+                device=self.device,
+                activation_function=self.activation_function,
+            )
+        else:
+            # Instantiate a model for this scenario.
+            # This code will be executed for a centralised critic that takes global inputs.
+            self.mlp = nn.ModuleList(
+                [
+                    MLP(
+                        in_features=self.input_features,
+                        out_features=self.output_features,
+                        device=self.device,
+                        activation_function=self.activation_function,
+                    )
+                    for _ in range(self.n_agents if not self.share_params else 1)
+                ]
+            )
+
+    def _perform_checks(self):
+        super()._perform_checks()
+
+        # Run some checks
+        if self.input_has_agent_dim and self.input_leaf_spec.shape[-2] != self.n_agents:
+            raise ValueError(
+                "If the MLP input has the agent dimension,"
+                " the second to last spec dimension should be the number of agents"
+            )
+        if (
+            self.output_has_agent_dim
+            and self.output_leaf_spec.shape[-2] != self.n_agents
+        ):
+            raise ValueError(
+                "If the MLP output has the agent dimension,"
+                " the second to last spec dimension should be the number of agents"
+            )
+
+    def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Gather in_key
+        input = tensordict.get(self.in_key)
+
+        # Input has multi-agent input dimension
+        if self.input_has_agent_dim:
+            res = self.mlp.forward(input)
+            if not self.output_has_agent_dim:
+                # If we are here, the module is centralised and parameter shared.
+                # Thus the multi-agent dimension has been expanded,
+                # and we can remove it without loss of data.
+                res = res[..., 0, :]
+
+        # Input does not have multi-agent input dimension
+        else:
+            if not self.share_params:
+                res = torch.stack(
+                    [net(input) for net in self.mlp],
+                    dim=-2,
+                )
+            else:
+                res = self.mlp[0](input)
+
+        tensordict.set(self.out_key, res)
+        return tensordict
+
+
+@dataclass
+class CustomModelConfig(ModelConfig):
+    # The config parameters for this class; these will be loaded from yaml.
+    custom_param: int = MISSING
+    activation_function: Type[nn.Module] = MISSING
+
+    @staticmethod
+    def associated_class():
+        # The associated model class
+        return CustomModel
+
+    @staticmethod
+    def get_from_yaml(path: Optional[str] = None) -> CustomModelConfig:
+        if path is None:
+            # If get_from_yaml is called without a path, we load from
+            # benchmarl/conf/model/layers/{CustomModelConfig.associated_class().__name__}
+            return CustomModelConfig(
+                **ModelConfig._load_from_yaml(
+                    name=CustomModelConfig.associated_class().__name__,
+                )
+            )
+        else:
+            # Otherwise, we load it from the given absolute path.
+            return CustomModelConfig(**parse_model_config(read_yaml_config(path)))
diff --git a/examples/extending/model/custommodel.yaml b/examples/extending/model/custommodel.yaml
new file mode 100644
index 00000000..ed641488
--- /dev/null
+++ b/examples/extending/model/custommodel.yaml
@@ -0,0 +1,4 @@
+
+name: custom_model
+custom_param: 3
+activation_function: torch.nn.Tanh
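Once registered (or imported directly), the custom model plugs into the programmatic `Experiment` API like any built-in model config. A sketch, assuming the `vmas` extra is installed and `custom_model.py` is importable:

```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from torch import nn

from custom_model import CustomModelConfig  # assuming the example file is importable

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=MappoConfig.get_from_yaml(),
    # Use the custom model for the policy networks.
    model_config=CustomModelConfig(custom_param=3, activation_function=nn.Tanh),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()
```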
diff --git a/examples/extending/task/README.md b/examples/extending/task/README.md
index e4ecd9e0..81c82beb 100644
--- a/examples/extending/task/README.md
+++ b/examples/extending/task/README.md
@@ -1,16 +1,23 @@
 
 # Creating a new task (from a new environment)
 
-Here are the steps to create a new algorithm. You can find the custom IQL algorithm
-created for this example in [`custom_algorithm.py`](custom_algorithm.py).
+Here are the steps to create a new task.
 
 1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py).
 This is an enum with task entries and some abstract functions you need to implement.
-2. Create a `customenv` folder with a yaml configuration file for each of your tasks
+2. Create a `customenv` folder with a yaml configuration file for each of your tasks.
 You can see [`customenv`](customenv) for an example.
 3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and
 your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to
 override from).
 4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the
 [`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks.
+5. Load it with
+```bash
+python benchmarl/run.py task=customenv/task_name algorithm=...
+```
+
+6. (Optional) You can create python dataclasses to use as schemas for your tasks
+to validate their config. We are not going to illustrate this here, but if
+you want to see an example, check out [`benchmarl/environments/vmas`](../../../benchmarl/environments/vmas).
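For step 4, the registry update amounts to one dictionary entry per task. A sketch of what that could look like; `TASK_1`/`task_1`, `TASK_2`/`task_2`, and the module path are placeholders for your actual enum entries and file layout:

```python
# Sketch of the addition to benchmarl/environments/__init__.py, where
# task_config_registry is defined. The entry names below are placeholders
# for whatever entries your CustomEnvTask enum defines.
from benchmarl.environments.customenv.common import CustomEnvTask

task_config_registry.update(
    {
        "customenv/task_1": CustomEnvTask.TASK_1,
        "customenv/task_2": CustomEnvTask.TASK_2,
    }
)
```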
diff --git a/examples/extending/task/custom_task.py b/examples/extending/task/custom_task.py
index 860dcee3..91e57fd4 100644
--- a/examples/extending/task/custom_task.py
+++ b/examples/extending/task/custom_task.py
@@ -25,7 +25,7 @@ def get_env_fun(
     ) -> Callable[[], EnvBase]:
         return lambda: YourTorchRLEnvConstructor(
             scenario=self.name.lower(),
-            num_envs=num_envs,  # Number of vectorized envs (do not use this param if the env is not vecotrized)
+            num_envs=num_envs,  # Number of vectorized envs (do not use this param if the env is not vectorized)
             continuous_actions=continuous_actions,  # Ignore this param if your env does not have this choice
             seed=seed,
             device=device,
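As a quick smoke test of the wiring above, the callable returned by `get_env_fun` can be invoked directly. A sketch, with `TASK_1` as a placeholder entry name and the yaml loading mirroring how BenchMARL normally initialises tasks:

```python
from custom_task import CustomEnvTask  # assuming the example file is importable

# TASK_1 is a placeholder for one of your enum's entries.
task = CustomEnvTask.TASK_1.get_from_yaml()  # load the task's yaml config
env_fun = task.get_env_fun(
    num_envs=4,
    continuous_actions=True,
    seed=0,
    device="cpu",
)
env = env_fun()  # builds the wrapped TorchRL env
print(env.reset())
```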