Commit 731a3ff
Signed-off-by: Matteo Bettini <[email protected]>
1 parent 8bec460
Showing 6 changed files with 385 additions and 0 deletions.
@@ -0,0 +1,18 @@
# Creating a new algorithm

Here are the steps to create a new algorithm. You can find the custom IQL algorithm
created for this example in [`custom_algorithm.py`](custom_algorithm.py).

1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example
   in [`custom_algorithm.py`](custom_algorithm.py). These will be the algorithm code
   and an associated dataclass used to validate loaded configs.
2. Create a `customalgorithm.yaml` with the configuration parameters you defined
   in your script. Make sure it has `customalgorithm_config` within its defaults at
   the top of the file to let Hydra know which Python dataclass it is
   associated with. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml)
   for an example.
3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and
   your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to
   override from).
4. Add `{"custom_algorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) (see the sketch below).
@@ -0,0 +1,254 @@
from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Optional, Tuple, Type

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.modules import EGreedyModule, QValueModule
from torchrl.objectives import DQNLoss, LossModule, ValueEstimators


class CustomIqlAlgorithm(Algorithm):
    def __init__(
        self, delay_value: bool, loss_function: str, my_custom_arg: int, **kwargs
    ):
        # In the init function you can define the init parameters you need, just make sure
        # to pass the kwargs to the super() class
        super().__init__(**kwargs)

        self.delay_value = delay_value
        self.loss_function = loss_function
        self.my_custom_arg = my_custom_arg

        # Throughout the class you have access to many extra things, like:
        self.my_custom_method()  # Custom methods
        _ = self.experiment_config  # Experiment config
        _ = self.model_config  # Policy config
        _ = self.critic_model_config  # Eventual critic config
        _ = self.group_map  # The group to agent names map

        # Specs
        _ = self.observation_spec
        _ = self.action_spec
        _ = self.state_spec
        _ = self.action_mask_spec

    #############################
    # Overridden abstract methods
    #############################

    def _get_loss(
        self, group: str, policy_for_loss: TensorDictModule, continuous: bool
    ) -> Tuple[LossModule, bool]:
        if continuous:
            raise NotImplementedError(
                "Custom IQL is not compatible with continuous actions."
            )
        else:
            # Loss
            loss_module = DQNLoss(
                policy_for_loss,
                delay_value=self.delay_value,
                loss_function=self.loss_function,
                action_space=self.action_spec[group, "action"],
            )
            # Always tell the loss where to find the data.
            # You can make sure the data is in the right place in self.process_batch.
            # This loss, for example, expects all data to have the multi-agent dimension,
            # so we take care of that in self.process_batch.
            loss_module.set_keys(
                reward=(group, "reward"),
                action=(group, "action"),
                done=(group, "done"),
                terminated=(group, "terminated"),
                action_value=(group, "action_value"),
                value=(group, "chosen_action_value"),
                priority=(group, "td_error"),
            )
            # Choose your value estimator, see what is available in the ValueEstimators enum
            loss_module.make_value_estimator(
                ValueEstimators.TD0, gamma=self.experiment_config.gamma
            )
            # This loss has delayed target parameters, so the second return value is True
            return loss_module, True

    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
        # For each loss name, associate the parameters you want
        # You can optionally modify (aggregate) loss names in self.process_loss_vals()
        return {"loss": loss.parameters()}

    def _get_policy_for_loss(
        self, group: str, model_config: ModelConfig, continuous: bool
    ) -> TensorDictModule:
        if continuous:
            raise ValueError("This should never happen")

        # The number of agents in the group
        n_agents = len(self.group_map[group])
        # The shape of the discrete action
        logits_shape = [
            *self.action_spec[group, "action"].shape,
            self.action_spec[group, "action"].space.n,
        ]

        # This is the spec of the policy input for this group
        actor_input_spec = CompositeSpec(
            {
                group: CompositeSpec(
                    {
                        "observation": self.observation_spec[group]["observation"]
                        .clone()
                        .to(self.device)
                    },
                    shape=(n_agents,),
                )
            }
        )
        # This is the spec of the policy output for this group
        actor_output_spec = CompositeSpec(
            {
                group: CompositeSpec(
                    {"action_value": UnboundedContinuousTensorSpec(shape=logits_shape)},
                    shape=(n_agents,),
                )
            }
        )
        # This is our neural policy
        actor_module = model_config.get_model(
            input_spec=actor_input_spec,
            output_spec=actor_output_spec,
            agent_group=group,
            input_has_agent_dim=True,  # Always true for a policy
            n_agents=n_agents,
            centralised=False,  # Always false for a policy
            share_params=self.experiment_config.share_policy_params,
            device=self.device,
        )
        if self.action_mask_spec is not None:
            action_mask_key = (group, "action_mask")
        else:
            action_mask_key = None

        value_module = QValueModule(
            action_value_key=(group, "action_value"),
            action_mask_key=action_mask_key,
            out_keys=[
                (group, "action"),
                (group, "action_value"),
                (group, "chosen_action_value"),
            ],
            spec=self.action_spec[group, "action"],
            action_space=None,  # We already passed the spec
        )

        # Here we chain the actor and the value module to get our policy
        return TensorDictSequential(actor_module, value_module)

    def _get_policy_for_collection(
        self, policy_for_loss: TensorDictModule, group: str, continuous: bool
    ) -> TensorDictModule:

        if self.action_mask_spec is not None:
            action_mask_key = (group, "action_mask")
        else:
            action_mask_key = None
        # Add exploration for collection
        greedy = EGreedyModule(
            annealing_num_steps=self.experiment_config.get_exploration_anneal_frames(
                self.on_policy
            ),
            action_key=(group, "action"),
            spec=self.action_spec[(group, "action")],
            action_mask_key=action_mask_key,
            eps_init=self.experiment_config.exploration_eps_init,
            eps_end=self.experiment_config.exploration_eps_end,
        )
        return TensorDictSequential(*policy_for_loss, greedy)

    def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
        # Here we make sure that all entries have the desired shape;
        # thus, in case there are shared dones, terminateds, or rewards, we expand them

        keys = list(batch.keys(True, True))
        group_shape = batch.get(group).shape

        nested_done_key = ("next", group, "done")
        nested_terminated_key = ("next", group, "terminated")
        nested_reward_key = ("next", group, "reward")

        if nested_done_key not in keys:
            batch.set(
                nested_done_key,
                batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
            )
        if nested_terminated_key not in keys:
            batch.set(
                nested_terminated_key,
                batch.get(("next", "terminated"))
                .unsqueeze(-1)
                .expand((*group_shape, 1)),
            )

        if nested_reward_key not in keys:
            batch.set(
                nested_reward_key,
                batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
            )

        return batch

    #####################
    # Custom new methods
    #####################
    def my_custom_method(self):
        pass

@dataclass
class CustomIqlConfig(AlgorithmConfig):
    # This is a class representing the configuration of your algorithm.
    # It will be used to validate loaded configs, so that every time you load this algorithm
    # we know exactly which parameters to expect and with which types

    # This is the list of args passed to your algorithm
    delay_value: bool = MISSING
    loss_function: str = MISSING
    my_custom_arg: int = MISSING

    @staticmethod
    def associated_class() -> Type[Algorithm]:
        # The associated algorithm class
        return CustomIqlAlgorithm

    @staticmethod
    def supports_continuous_actions() -> bool:
        # Is it compatible with continuous actions?
        return False

    @staticmethod
    def supports_discrete_actions() -> bool:
        # Is it compatible with discrete actions?
        return True

    @staticmethod
    def on_policy() -> bool:
        # Should it be trained on or off policy?
        return False

    @staticmethod
    def get_from_yaml(path: Optional[str] = None):
        if path is None:
            # If get_from_yaml is called without a path,
            # we load from benchmarl/conf/algorithm/{CustomIqlConfig.associated_class().__name__}
            return CustomIqlConfig(
                **AlgorithmConfig._load_from_yaml(
                    name=CustomIqlConfig.associated_class().__name__,
                )
            )
        else:
            # Otherwise, we load it from the given absolute path
            return CustomIqlConfig(**read_yaml_config(path))
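As a usage note, the dataclass above can be filled either from its YAML file (shown next in this commit) or directly in Python. A minimal sketch, assuming the module is importable as `custom_algorithm`:

```python
from custom_algorithm import CustomIqlConfig  # hypothetical import path for this example

# Direct instantiation: every MISSING field must be provided explicitly
algorithm_config = CustomIqlConfig(delay_value=True, loss_function="l2", my_custom_arg=3)

# Or load the default values from benchmarl/conf/algorithm
algorithm_config = CustomIqlConfig.get_from_yaml()
```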
@@ -0,0 +1,7 @@
defaults:
  - customalgorithm_config
  - _self_

delay_value: True
loss_function: "l2"
my_custom_arg: 3
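The `customalgorithm_config` entry in `defaults` refers to the structured-config schema that BenchMARL registers in Hydra's ConfigStore for the algorithm dataclass, which is what lets Hydra validate these values against the dataclass fields. The registration code is not shown in this commit; the snippet below is only a hedged sketch of what such a registration might look like:

```python
from hydra.core.config_store import ConfigStore

from custom_algorithm import CustomIqlConfig  # hypothetical import path for this example

cs = ConfigStore.instance()
# Store the dataclass in the "algorithm" group under the name referenced by the yaml defaults list
cs.store(name="customalgorithm_config", group="algorithm", node=CustomIqlConfig)
```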
File renamed without changes.
@@ -0,0 +1,16 @@
# Creating a new task (from a new environment)

Here are the steps to create a new task. You can find the custom environment tasks
created for this example in [`custom_task.py`](custom_task.py).

1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py).
   This is an enum with task entries and some abstract functions you need to implement.
2. Create a `customenv` folder with a yaml configuration file for each of your tasks.
   You can see [`customenv`](customenv) for an example.
3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and
   your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to
   override from).
4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the
   [`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks
   (see the sketch below).
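For step 4, a hedged sketch of the registry update is shown below. The actual contents of `benchmarl/environments/__init__.py` are not part of this commit, so the import path and the surrounding entries are assumptions:

```python
# Hypothetical excerpt of benchmarl/environments/__init__.py
from benchmarl.environments.customenv.common import CustomEnvTask

task_config_registry = {
    # ... entries for the built-in environments ...
    # One entry per task, keyed as "<env_name>/<task_name>"
    "customenv/task_1": CustomEnvTask.TASK_1,
    "customenv/task_2": CustomEnvTask.TASK_2,
}
```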
@@ -0,0 +1,90 @@
from typing import Callable, Dict, List, Optional

from benchmarl.environments.common import Task
from benchmarl.utils import DEVICE_TYPING

from tensordict import TensorDictBase
from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase
from torchrl.envs.libs import YourTorchRLEnvConstructor


class CustomEnvTask(Task):
    # Your task names.
    # Their config will be loaded from benchmarl/conf/task/customenv

    TASK_1 = None  # Loaded automatically from benchmarl/conf/task/customenv/task_1
    TASK_2 = None  # Loaded automatically from benchmarl/conf/task/customenv/task_2

    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
        device: DEVICE_TYPING,
    ) -> Callable[[], EnvBase]:
        return lambda: YourTorchRLEnvConstructor(
            scenario=self.name.lower(),
            num_envs=num_envs,  # Number of vectorized envs (do not use this param if the env is not vectorized)
            continuous_actions=continuous_actions,  # Ignore this param if your env does not have this choice
            seed=seed,
            device=device,
            categorical_actions=True,  # If your env has discrete actions, they need to be categorical (TorchRL can help with this)
            **self.config,  # Pass the loaded config (this is what is in your yaml)
        )

    def supports_continuous_actions(self) -> bool:
        # Does the environment support continuous actions?
        return True

    def supports_discrete_actions(self) -> bool:
        # Does the environment support discrete actions?
        return True

    def has_render(self, env: EnvBase) -> bool:
        # Does the env have an env.render(mode="rgb_array") or env.render() function?
        return True

    def max_steps(self, env: EnvBase) -> int:
        # Maximum number of steps for a rollout during evaluation
        return 100

    def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        # The group map mapping group names to agent names
        # The data in the tensordict will be presented this way
        return {"agents": [agent.name for agent in env.agents]}

    def observation_spec(self, env: EnvBase) -> CompositeSpec:
        # A spec for the observation.
        # Must be a CompositeSpec with one (group_name, "observation") entry per group.
        return env.full_observation_spec

    def action_spec(self, env: EnvBase) -> CompositeSpec:
        # A spec for the action.
        # If provided, must be a CompositeSpec with one (group_name, "action") entry per group.
        return env.full_action_spec

    def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        # A spec for the state.
        # If provided, must be a CompositeSpec with one "state" entry.
        return None

    def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        # A spec for the action mask.
        # If provided, must be a CompositeSpec with one (group_name, "action_mask") entry per group.
        return None

    def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        # A spec for the info.
        # If provided, must be a CompositeSpec with one (group_name, "info") entry per group (this entry can be composite).
        return None

    @staticmethod
    def env_name() -> str:
        # The name of the environment in the benchmarl/conf/task folder
        return "customenv"

    def log_info(self, batch: TensorDictBase) -> Dict[str, float]:
        # Optionally return a str->float dict with extra things to log
        # This function has access to the collected batch and is optional
        return {}
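To tie the two examples together, here is a sketch of how the custom task and the custom algorithm could be run in a BenchMARL experiment, following the library's usual `Experiment` pattern. The import paths below assume the files were placed as described in the READMEs and are therefore assumptions, not part of this commit:

```python
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

# Hypothetical import paths, assuming the scripts were moved into the benchmarl package
from benchmarl.algorithms.custom_algorithm import CustomIqlConfig
from benchmarl.environments.customenv.common import CustomEnvTask

experiment = Experiment(
    task=CustomEnvTask.TASK_1.get_from_yaml(),  # Task config from benchmarl/conf/task/customenv/task_1
    algorithm_config=CustomIqlConfig.get_from_yaml(),  # Algorithm config from benchmarl/conf/algorithm
    model_config=MlpConfig.get_from_yaml(),  # Policy model
    critic_model_config=MlpConfig.get_from_yaml(),  # Eventual critic model (IQL does not use a critic)
    seed=0,
    config=ExperimentConfig.get_from_yaml(),  # Training and evaluation hyperparameters
)
experiment.run()
```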