
[Example] Examples about extending #22

Merged · 2 commits · Oct 9, 2023
8 changes: 6 additions & 2 deletions examples/extending/algorithm/README.md
@@ -6,13 +6,17 @@ created for this example in [`custom_algorithm.py`](custom_algorithm.py).

1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example
in [`custom_algorithm.py`](custom_algorithm.py). These will be the algorithm code
-and a associated dataclass to validate loaded configs.
+and an associated dataclass to validate loaded configs.
2. Create a `customalgorithm.yaml` with the configuration parameters you defined
in your script. Make sure it has `customalgorithm_config` within its defaults at
the top of the file to let Hydra know which Python dataclass it is
associated with. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml)
-for an example
+for an example.
3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and
your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to
override from).
4. Add `{"customalgorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) (a sketch of this entry follows the list)
5. Load it with
```bash
python benchmarl/run.py algorithm=customalgorithm task=...
```
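
A minimal sketch of the registry entry from step 4 (the import path and the neighbouring `"iql"` entry are illustrative assumptions, not verbatim repo code):

```python
# benchmarl/algorithms/__init__.py (sketch)
from benchmarl.algorithms.customalgorithm import CustomAlgorithmConfig

algorithm_config_registry = {
    # ... existing algorithm entries, e.g. "iql": IqlConfig ...
    "customalgorithm": CustomAlgorithmConfig,
}
```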
7 changes: 7 additions & 0 deletions examples/extending/algorithm/custom_algorithm.py
@@ -201,6 +201,13 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

        return batch

    def process_loss_vals(
        self, group: str, loss_vals: TensorDictBase
    ) -> TensorDictBase:
        # Here you can modify the loss_vals tensordict containing entries loss_name->loss_value.
        # For example, you can sum two entries into a new entry to optimize them together.
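        # A hypothetical sketch (these entry names are assumptions, not part of the example):
        #   loss_vals["loss_total"] = loss_vals["loss_actor"] + loss_vals["loss_alpha"]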
        return loss_vals

    #####################
    # Custom new methods
    #####################
6 changes: 3 additions & 3 deletions examples/extending/algorithm/customiqlalgorithm.yaml
@@ -2,6 +2,6 @@ defaults:
- customalgorithm_config
- _self_

-delay_value: bool = True
-loss_function: str = "l2"
-my_custom_arg: int = 3
+delay_value: True
+loss_function: "l2"
+my_custom_arg: 3
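
The diff above drops the inline type annotations from the yaml: typing and validation belong to the associated dataclass, and the yaml carries only values. A minimal sketch of the matching dataclass (import path and field layout assumed from the README's description):

```python
from dataclasses import dataclass, MISSING

from benchmarl.algorithms.common import AlgorithmConfig


@dataclass
class CustomAlgorithmConfig(AlgorithmConfig):
    # Types live here; customiqlalgorithm.yaml supplies the values.
    delay_value: bool = MISSING
    loss_function: str = MISSING
    my_custom_arg: int = MISSING
```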
20 changes: 20 additions & 0 deletions examples/extending/model/README.md
@@ -0,0 +1,20 @@

# Creating a new model

Here are the steps to create a new model.

1. Create your `CustomModel` and `CustomModelConfig` following the example
in [`custom_model.py`](custom_model.py). These will be the model code
and an associated dataclass to validate loaded configs.
2. Create a `custommodel.yaml` with the configuration parameters you defined
in your script. Make sure it has a `name` entry equal to `custom_model` to let Hydra know which Python dataclass it is
associated with. You can see [`custommodel.yaml`](custommodel.yaml)
for an example.
3. Place your model script in [`benchmarl/models`](../../../benchmarl/models) and
your config in [`benchmarl/conf/model/layers`](../../../benchmarl/conf/model/layers) (or any other place you want to
override from).
4. Add `{"custom_model": CustomModelConfig}` to the [`benchmarl.models.model_config_registry`](../../../benchmarl/models/__init__.py) (a sketch of this entry follows the list)
5. Load it with
```bash
python benchmarl/run.py model=layers/custommodel algorithm=... task=...
```
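
A minimal sketch of the registry entry from step 4 (the import path and the neighbouring `"mlp"` entry are illustrative assumptions, not verbatim repo code):

```python
# benchmarl/models/__init__.py (sketch)
from benchmarl.models.custom_model import CustomModelConfig

model_config_registry = {
    # ... existing model entries, e.g. "mlp": MlpConfig ...
    "custom_model": CustomModelConfig,
}
```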
182 changes: 182 additions & 0 deletions examples/extending/model/custom_model.py
@@ -0,0 +1,182 @@
from __future__ import annotations

from dataclasses import dataclass, MISSING
from typing import Optional, Type

import torch

from benchmarl.models.common import Model, ModelConfig, parse_model_config
from benchmarl.utils import read_yaml_config
from tensordict import TensorDictBase
from torch import nn
from torchrl.modules import MLP, MultiAgentMLP


class CustomModel(Model):
    def __init__(
        self,
        custom_param: int,
        activation_function: Type[nn.Module],
        **kwargs,
    ):
        # Models in BenchMARL are instantiated per agent group.
        # This means that each model will process the inputs for a whole group of agents.
        # There are some core attributes that models are created with,
        # which we are now going to illustrate.

        # Remember to pass the kwargs to the super() class
        super().__init__(**kwargs)

        # You can create your custom attributes
        self.custom_param = custom_param
        self.activation_function = activation_function

        # And access some of the ones already available to your module
        _ = self.input_spec  # Like its input_spec
        _ = self.output_spec  # or output_spec
        _ = self.agent_group  # the name of the agent group the model is for
        _ = self.n_agents  # or the number of agents this module is for

        # The following are some of the most important attributes of the model.
        # They decide how the model should be used.
        # Since models can be used for actors and critics, the model will behave differently
        # depending on how these attributes are set.
        # BenchMARL will take care of setting these attributes for you, but your role when making
        # a custom model is making sure that all cases are handled properly

        # This tells the model if the input will have a multi-agent dimension or not.
        # For example, the input of policies will always have this set to true,
        # but critics that use a global state have this set to false as the state
        # is shared by all agents
        _ = self.input_has_agent_dim

        # This tells the model if it should have only one set of parameters
        # or a different set of parameters for each agent.
        # This is independent of the other options as it is possible to have different parameters
        # for centralized critics with global input
        _ = self.share_params

        # This tells the model if it has full observability
        # This will always be true when self.input_has_agent_dim==False
        # but in cases where the input has the agent dimension, this parameter is
        # used to distinguish between a decentralised model (where each agent's data
        # is processed separately) and a centralized model, where the model pools all data together
        _ = self.centralised

        # This is a dynamically computed attribute that indicates if the output will have the agent dimension.
        # This will be false when share_params==True and centralised==True, and true in all other cases.
        # When output_has_agent_dim is true, your model's output should contain the multiagent dimension,
        # and the dimension should be absent otherwise
        _ = self.output_has_agent_dim

        self.input_features = self.input_leaf_spec.shape[-1]
        self.output_features = self.output_leaf_spec.shape[-1]

        if self.input_has_agent_dim and not self.centralised:
            # Instantiate a model for this scenario
            # This code will be executed for a policy or for a decentralized critic for example
            self.mlp = MultiAgentMLP(
                n_agent_inputs=self.input_features,
                n_agent_outputs=self.output_features,
                n_agents=self.n_agents,
                centralised=self.centralised,
                share_params=self.share_params,
                device=self.device,
                activation_function=self.activation_function,
            )
        elif self.input_has_agent_dim and self.centralised:
            # Instantiate a model for this scenario
            # This code will be executed for a centralized critic that takes local inputs
            self.mlp = MultiAgentMLP(
                n_agent_inputs=self.input_features,
                n_agent_outputs=self.output_features,
                n_agents=self.n_agents,
                centralised=self.centralised,
                share_params=self.share_params,
                device=self.device,
                activation_function=self.activation_function,
            )
        else:
            # Instantiate a model for this scenario
            # This code will be executed for a centralized critic that takes global inputs
            self.mlp = nn.ModuleList(
                [
                    MLP(
                        in_features=self.input_features,
                        out_features=self.output_features,
                        device=self.device,
                        activation_function=self.activation_function,
                    )
                    for _ in range(self.n_agents if not self.share_params else 1)
                ]
            )

    def _perform_checks(self):
        super()._perform_checks()

        # Run some checks
        if self.input_has_agent_dim and self.input_leaf_spec.shape[-2] != self.n_agents:
            raise ValueError(
                "If the MLP input has the agent dimension,"
                " the second to last spec dimension should be the number of agents"
            )
        if (
            self.output_has_agent_dim
            and self.output_leaf_spec.shape[-2] != self.n_agents
        ):
            raise ValueError(
                "If the MLP output has the agent dimension,"
                " the second to last spec dimension should be the number of agents"
            )

    def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Gather in_key
        input = tensordict.get(self.in_key)

        # Input has multi-agent input dimension
        if self.input_has_agent_dim:
            res = self.mlp.forward(input)
            if not self.output_has_agent_dim:
                # If we are here, the module is centralised and parameter shared.
                # Thus the multi-agent dimension has been expanded,
                # so we remove it without loss of data
                res = res[..., 0, :]

        # Input does not have multi-agent input dimension
        else:
            if not self.share_params:
                res = torch.stack(
                    [net(input) for net in self.mlp],
                    dim=-2,
                )
            else:
                res = self.mlp[0](input)

        tensordict.set(self.out_key, res)
        return tensordict


@dataclass
class CustomModelConfig(ModelConfig):
    # The config parameters for this class; these will be loaded from yaml
    custom_param: int = MISSING
    activation_function: Type[nn.Module] = MISSING

    @staticmethod
    def associated_class():
        # The associated model class
        return CustomModel

    @staticmethod
    def get_from_yaml(path: Optional[str] = None) -> CustomModelConfig:
        if path is None:
            # If get_from_yaml is called without a path,
            # we load from benchmarl/conf/model/layers/{CustomModelConfig.associated_class().__name__}
            return CustomModelConfig(
                **ModelConfig._load_from_yaml(
                    name=CustomModelConfig.associated_class().__name__,
                )
            )
        else:
            # Otherwise, we load it from the given absolute path
            return CustomModelConfig(**parse_model_config(read_yaml_config(path)))
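
A short usage sketch of the config above, building it directly in code (mirroring what `get_from_yaml` would produce from `custommodel.yaml`):

```python
from torch import nn

config = CustomModelConfig(
    custom_param=3,
    activation_function=nn.Tanh,
)
assert config.associated_class() is CustomModel
```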
4 changes: 4 additions & 0 deletions examples/extending/model/custommodel.yaml
@@ -0,0 +1,4 @@

name: custom_model
custom_param: 3
activation_function: torch.nn.Tanh
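
The `activation_function` entry is a dotted class path. A generic way to resolve such a string into the class it names (an illustrative helper, not BenchMARL's own loader):

```python
import importlib


def class_from_path(path: str) -> type:
    # "torch.nn.Tanh" -> <class 'torch.nn.modules.activation.Tanh'>
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


activation = class_from_path("torch.nn.Tanh")
```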
13 changes: 10 additions & 3 deletions examples/extending/task/README.md
@@ -1,16 +1,23 @@

# Creating a new task (from a new environment)

-Here are the steps to create a new algorithm. You can find the custom IQL algorithm
-created for this example in [`custom_agorithm.py`](custom_algorithm.py).
+Here are the steps to create a new task.

1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py).
This is an enum with task entries and some abstract functions you need to implement.

-2. Create a `customenv` folder with a yaml configuration file for each of your tasks
+2. Create a `customenv` folder with a yaml configuration file for each of your tasks.
You can see [`customenv`](customenv) for an example.
3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and
your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to
override from).
4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the
[`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks (a sketch of this entry follows the list).
5. Load it with
```bash
python benchmarl/run.py task=customenv/task_name algorithm=...
```

6. (Optional) You can create Python dataclasses to use as schemas for your tasks
to validate their config. We are not going to illustrate this here, but if
you want to see an example, check out [`benchmarl/environments/vmas`](../../../benchmarl/environments/vmas).
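
A minimal sketch of the registry entry from step 4 (`task_1` is an assumed task name for illustration):

```python
# benchmarl/environments/__init__.py (sketch)
from benchmarl.environments.customenv.common import CustomEnvTask

task_config_registry = {
    # ... existing task entries ...
    "customenv/task_1": CustomEnvTask.TASK_1,
}
```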
2 changes: 1 addition & 1 deletion examples/extending/task/custom_task.py
@@ -25,7 +25,7 @@ def get_env_fun(
    ) -> Callable[[], EnvBase]:
        return lambda: YourTorchRLEnvConstructor(
            scenario=self.name.lower(),
-            num_envs=num_envs,  # Number of vectorized envs (do not use this param if the env is not vecotrized)
+            num_envs=num_envs,  # Number of vectorized envs (do not use this param if the env is not vectorized)
            continuous_actions=continuous_actions,  # Ignore this param if your env does not have this choice
            seed=seed,
            device=device,
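
For context, the enum that `get_env_fun` belongs to might look like the minimal sketch below (the entry name is an assumption; see [`custom_task.py`](custom_task.py) for the full example):

```python
from benchmarl.environments.common import Task


class CustomEnvTask(Task):
    # One entry per task in your environment; its yaml config is
    # attached by BenchMARL when the task is loaded.
    TASK_1 = None
```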