[Example] Custom task and algorithm
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 8, 2023
1 parent 8bec460 commit 731a3ff
Showing 6 changed files with 385 additions and 0 deletions.
18 changes: 18 additions & 0 deletions examples/extending/algorithm/README.md
@@ -0,0 +1,18 @@

# Creating a new algorithm

Here are the steps to create a new algorithm. You can find the custom IQL algorithm
created for this example in [`custom_algorithm.py`](custom_algorithm.py).

1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example
in [`custom_algorithm.py`](custom_algorithm.py). These will be the algorithm code
and an associated dataclass to validate loaded configs.
2. Create a `customalgorithm.yaml` with the configuration parameters you defined
in your script. Make sure it has `customalgorithm_config` within its defaults at
the top of the file to let Hydra know which Python dataclass it is
associated with. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml)
for an example.
3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and
your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to
override from).
4. Add `{"custom_algorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) (a minimal sketch follows below).
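
A minimal sketch of steps 2-4 in code form (the `custom_algorithm` registry key and the
import path are assumptions made for this example; in practice you would edit
`benchmarl/algorithms/__init__.py` directly rather than mutate the registry at runtime):

```python
# Sketch: make the new algorithm discoverable by name and load its config.
from benchmarl.algorithms import algorithm_config_registry

from custom_algorithm import CustomAlgorithmConfig  # hypothetical import path

# Step 4: register the config class under a name usable from the CLI / Hydra
algorithm_config_registry["custom_algorithm"] = CustomAlgorithmConfig

# With the yaml from steps 2-3 placed in benchmarl/conf/algorithm,
# the config can now be loaded and validated
algorithm_config = CustomAlgorithmConfig.get_from_yaml()
```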
254 changes: 254 additions & 0 deletions examples/extending/algorithm/custom_algorithm.py
@@ -0,0 +1,254 @@
from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Optional, Tuple, Type

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.modules import EGreedyModule, QValueModule
from torchrl.objectives import DQNLoss, LossModule, ValueEstimators


class CustomIqlAlgorithm(Algorithm):
def __init__(
self, delay_value: bool, loss_function: str, my_custom_arg: int, **kwargs
):
        # In the init function you can define the init parameters you need; just make sure
        # to pass the remaining kwargs to the superclass
super().__init__(**kwargs)

self.delay_value = delay_value
self.loss_function = loss_function
self.my_custom_arg = my_custom_arg

        # Throughout the class you have access to many useful attributes, such as:
        self.my_custom_method()  # Custom methods
        _ = self.experiment_config  # Experiment config
        _ = self.model_config  # Policy config
        _ = self.critic_model_config  # Critic config (if any)
        _ = self.group_map  # The group-to-agent-names map

# Specs
_ = self.observation_spec
_ = self.action_spec
_ = self.state_spec
_ = self.action_mask_spec

#############################
# Overridden abstract methods
#############################

def _get_loss(
self, group: str, policy_for_loss: TensorDictModule, continuous: bool
) -> Tuple[LossModule, bool]:
if continuous:
raise NotImplementedError(
"Custom Iql is not compatible with continuous actions."
)
else:
# Loss
loss_module = DQNLoss(
policy_for_loss,
delay_value=self.delay_value,
loss_function=self.loss_function,
action_space=self.action_spec[group, "action"],
)
            # Always tell the loss where to find its data.
            # You can make sure the data is in the right place in self.process_batch.
            # This loss, for example, expects all entries to have the multi-agent dimension,
            # so we take care of that in self.process_batch.
loss_module.set_keys(
reward=(group, "reward"),
action=(group, "action"),
done=(group, "done"),
terminated=(group, "terminated"),
action_value=(group, "action_value"),
value=(group, "chosen_action_value"),
priority=(group, "td_error"),
)
# Choose your value estimator, see what is available in the ValueEstimators enum
loss_module.make_value_estimator(
ValueEstimators.TD0, gamma=self.experiment_config.gamma
)
            # This loss has delayed target parameters, so the second return value is True
return loss_module, True

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
        # For each loss name, associate the parameters you want to optimize
# You can optionally modify (aggregate) loss names in self.process_loss_vals()
return {"loss": loss.parameters()}

def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:
if continuous:
raise ValueError("This should never happen")

# The number of agents in the group
n_agents = len(self.group_map[group])
# The shape of the discrete action
logits_shape = [
*self.action_spec[group, "action"].shape,
self.action_spec[group, "action"].space.n,
]

# This is the spec of the policy input for this group
actor_input_spec = CompositeSpec(
{
group: CompositeSpec(
{
"observation": self.observation_spec[group]["observation"]
.clone()
.to(self.device)
},
shape=(n_agents,),
)
}
)
# This is the spec of the policy output for this group
actor_output_spec = CompositeSpec(
{
group: CompositeSpec(
{"action_value": UnboundedContinuousTensorSpec(shape=logits_shape)},
shape=(n_agents,),
)
}
)
# This is our neural policy
actor_module = model_config.get_model(
input_spec=actor_input_spec,
output_spec=actor_output_spec,
agent_group=group,
input_has_agent_dim=True, # Always true for a policy
n_agents=n_agents,
centralised=False, # Always false for a policy
share_params=self.experiment_config.share_policy_params,
device=self.device,
)
if self.action_mask_spec is not None:
action_mask_key = (group, "action_mask")
else:
action_mask_key = None

value_module = QValueModule(
action_value_key=(group, "action_value"),
action_mask_key=action_mask_key,
out_keys=[
(group, "action"),
(group, "action_value"),
(group, "chosen_action_value"),
],
spec=self.action_spec[group, "action"],
action_space=None, # We already passed the spec
)

# Here we chain the actor and the value module to get our policy
return TensorDictSequential(actor_module, value_module)

def _get_policy_for_collection(
self, policy_for_loss: TensorDictModule, group: str, continuous: bool
) -> TensorDictModule:

if self.action_mask_spec is not None:
action_mask_key = (group, "action_mask")
else:
action_mask_key = None
# Add exploration for collection
greedy = EGreedyModule(
annealing_num_steps=self.experiment_config.get_exploration_anneal_frames(
self.on_policy
),
action_key=(group, "action"),
spec=self.action_spec[(group, "action")],
action_mask_key=action_mask_key,
eps_init=self.experiment_config.exploration_eps_init,
eps_end=self.experiment_config.exploration_eps_end,
)
return TensorDictSequential(*policy_for_loss, greedy)

def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
        # Here we make sure that all entries have the desired shape;
        # in case dones, terminations, or rewards are shared across the group, we expand them

keys = list(batch.keys(True, True))
group_shape = batch.get(group).shape

nested_done_key = ("next", group, "done")
nested_terminated_key = ("next", group, "terminated")
nested_reward_key = ("next", group, "reward")

if nested_done_key not in keys:
batch.set(
nested_done_key,
batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
)
if nested_terminated_key not in keys:
batch.set(
nested_terminated_key,
batch.get(("next", "terminated"))
.unsqueeze(-1)
.expand((*group_shape, 1)),
)

if nested_reward_key not in keys:
batch.set(
nested_reward_key,
batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
)

return batch

#####################
# Custom new methods
#####################
def my_custom_method(self):
pass


@dataclass
class CustomIqlConfig(AlgorithmConfig):
    # This is a class representing the configuration of your algorithm.
    # It will be used to validate loaded configs, so that every time you load this algorithm
    # we know exactly which parameters to expect and their types.

# This is a list of args passed to your algorithm
delay_value: bool = MISSING
loss_function: str = MISSING
my_custom_arg: int = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
# The associated algorithm class
return CustomIqlAlgorithm

@staticmethod
def supports_continuous_actions() -> bool:
# Is it compatible with continuous actions?
return False

@staticmethod
def supports_discrete_actions() -> bool:
# Is it compatible with discrete actions?
return True

@staticmethod
def on_policy() -> bool:
# Should it be trained on or off policy?
return False

@staticmethod
def get_from_yaml(path: Optional[str] = None):
if path is None:
# If get_from_yaml is called without a path,
# we load from benchmarl/conf/algorithm/{CustomIqlConfig.associated_class().__name__}
return CustomIqlConfig(
**AlgorithmConfig._load_from_yaml(
name=CustomIqlConfig.associated_class().__name__,
)
)
else:
# Otherwise, we load it from the given absolute path
return CustomIqlConfig(**read_yaml_config(path))
7 changes: 7 additions & 0 deletions examples/extending/algorithm/customiqlalgorithm.yaml
@@ -0,0 +1,7 @@
defaults:
- customalgorithm_config
- _self_

delay_value: True
loss_function: "l2"
my_custom_arg: 3
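
As a quick check, this file can be loaded into the validating dataclass defined in
`custom_algorithm.py` above (a sketch; the relative path is just the location used in this
example, and it assumes `read_yaml_config` ignores the Hydra `defaults` list):

```python
# Sketch: load the yaml above through the path branch of get_from_yaml.
from custom_algorithm import CustomIqlConfig  # hypothetical import path

config = CustomIqlConfig.get_from_yaml(
    path="examples/extending/algorithm/customiqlalgorithm.yaml"
)
print(config.delay_value, config.loss_function, config.my_custom_arg)  # True l2 3
```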
File renamed without changes.
16 changes: 16 additions & 0 deletions examples/extending/task/README.md
@@ -0,0 +1,16 @@

# Creating a new task (from a new environment)

Here are the steps to create a new task (from a new environment). You can find the custom
environment tasks created for this example in [`custom_task.py`](custom_task.py).

1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py).
This is an enum with task entries and some abstract functions you need to implement.

2. Create a `customenv` folder with a yaml configuration file for each of your tasks.
You can see [`customenv`](customenv) for an example.
3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and
your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to
override from).
4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the
[`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks
(a minimal sketch follows below).
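
A minimal sketch of step 4 (the task names and import path are the ones assumed in this
example; in practice you would add the entries directly inside
`benchmarl/environments/__init__.py`):

```python
# Sketch: register every task of the new environment under "customenv/<task_name>".
from benchmarl.environments import task_config_registry

from custom_task import CustomEnvTask  # hypothetical import path

task_config_registry.update(
    {
        "customenv/task_1": CustomEnvTask.TASK_1,
        "customenv/task_2": CustomEnvTask.TASK_2,
    }
)

# The tasks can then be selected by name from the CLI / Hydra, or loaded directly:
task = CustomEnvTask.TASK_1.get_from_yaml()
```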
90 changes: 90 additions & 0 deletions examples/extending/task/custom_task.py
@@ -0,0 +1,90 @@
from typing import Callable, Dict, List, Optional

from benchmarl.environments.common import Task
from benchmarl.utils import DEVICE_TYPING

from tensordict import TensorDictBase
from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase
from torchrl.envs.libs import YourTorchRLEnvConstructor


class CustomEnvTask(Task):
# Your task names.
# Their config will be loaded from benchmarl/conf/task/customenv

TASK_1 = None # Loaded automatically from benchmarl/conf/task/customenv/task_1
TASK_2 = None # Loaded automatically from benchmarl/conf/task/customenv/task_2

def get_env_fun(
self,
num_envs: int,
continuous_actions: bool,
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
return lambda: YourTorchRLEnvConstructor(
scenario=self.name.lower(),
            num_envs=num_envs,  # Number of vectorized envs (do not use this param if the env is not vectorized)
continuous_actions=continuous_actions, # Ignore this param if your env does not have this choice
seed=seed,
device=device,
categorical_actions=True, # If your env has discrete actions, they need to be categorical (TorchRL can help with this)
            **self.config,  # Pass the loaded config (this is what is in your yaml)
)

def supports_continuous_actions(self) -> bool:
# Does the environment support continuous actions?
return True

def supports_discrete_actions(self) -> bool:
# Does the environment support discrete actions?
return True

def has_render(self, env: EnvBase) -> bool:
        # Does the env have an env.render(mode="rgb_array") or env.render() function?
return True

def max_steps(self, env: EnvBase) -> int:
# Maximum number of steps for a rollout during evaluation
return 100

def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        # The group map mapping group names to agent names
        # The data in the tensordict will be presented this way
return {"agents": [agent.name for agent in env.agents]}

def observation_spec(self, env: EnvBase) -> CompositeSpec:
# A spec for the observation.
# Must be a CompositeSpec with one (group_name, "observation") entry per group.
return env.full_observation_spec

def action_spec(self, env: EnvBase) -> CompositeSpec:
        # A spec for the action.
# If provided, must be a CompositeSpec with one (group_name, "action") entry per group.
return env.full_action_spec

def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
# A spec for the state.
# If provided, must be a CompositeSpec with one "state" entry
return None

def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        # A spec for the action mask.
# If provided, must be a CompositeSpec with one (group_name, "action_mask") entry per group.
return None

def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        # A spec for the info.
# If provided, must be a CompositeSpec with one (group_name, "info") entry per group (this entry can be composite).
return None

@staticmethod
def env_name() -> str:
# The name of the environment in the benchmarl/conf/task folder
return "customenv"

def log_info(self, batch: TensorDictBase) -> Dict[str, float]:
# Optionally return a str->float dict with extra things to log
        # This function has access to the collected batch and is optional
return {}
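
Putting the two examples together, an experiment on the new task with the custom algorithm
could look like the following sketch (all custom names come from this example and assume the
registration steps above; model and experiment settings are the BenchMARL defaults):

```python
# Hypothetical end-to-end sketch combining the custom task and the custom algorithm.
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

from custom_algorithm import CustomIqlConfig  # hypothetical import paths
from custom_task import CustomEnvTask

experiment = Experiment(
    task=CustomEnvTask.TASK_1.get_from_yaml(),
    algorithm_config=CustomIqlConfig.get_from_yaml(),
    model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()
```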
