amend
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 8, 2023
1 parent f6803c7 commit 2803eae
Showing 5 changed files with 22 additions and 8 deletions.
6 changes: 5 additions & 1 deletion examples/extending/algorithm/README.md
@@ -11,8 +11,12 @@ and an associated dataclass to validate loaded configs.
in your script. Make sure it has `customalgorithm_config` within its defaults at
the top of the file to let Hydra know which Python dataclass it is
associated with. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml)
for an example
for an example.
3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and
your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to
override from)
4. Add `{"customalgorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) (a sketch of this registration follows step 5 below)
5. Load it with
```bash
python benchmarl/run.py algorithm=customalgorithm task=...
```
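
As a rough illustration of step 4 above, the registry addition might look like the sketch below. Only the `"customalgorithm"` entry reflects the step itself; the import path and the surrounding contents of `benchmarl/algorithms/__init__.py` are assumptions.

```python
# Hypothetical excerpt of benchmarl/algorithms/__init__.py after step 4.
# The import path assumes custom_algorithm.py was placed there in step 3;
# the "..." comment stands in for the algorithms already registered.
from benchmarl.algorithms.custom_algorithm import CustomAlgorithmConfig

algorithm_config_registry = {
    # ... existing BenchMARL algorithm configs ...
    "customalgorithm": CustomAlgorithmConfig,  # key used on the CLI as algorithm=customalgorithm
}
```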
7 changes: 7 additions & 0 deletions examples/extending/algorithm/custom_algorithm.py
@@ -201,6 +201,13 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

return batch

def process_loss_vals(
self, group: str, loss_vals: TensorDictBase
) -> TensorDictBase:
# Here you can modify the loss_vals tensordict, which contains entries loss_name -> loss_value.
# For example, you can sum two entries into a new entry to optimize them together.
return loss_vals

#####################
# Custom new methods
#####################
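
Following the suggestion in the comment above, here is a minimal sketch of an override that sums two loss entries into a new one. The entry names `loss_actor` and `loss_entropy` are invented for illustration, and the class is only a stand-in for the example's algorithm class.

```python
# Hypothetical sketch of process_loss_vals; "loss_actor" and "loss_entropy"
# are made-up entry names, not ones BenchMARL guarantees to exist.
from tensordict import TensorDictBase


class CustomAlgorithm:  # stand-in for the example's algorithm class
    def process_loss_vals(
        self, group: str, loss_vals: TensorDictBase
    ) -> TensorDictBase:
        # Expose the sum of two losses as a new entry so they are optimized together.
        loss_vals["loss_actor_and_entropy"] = (
            loss_vals["loss_actor"] + loss_vals["loss_entropy"]
        )
        return loss_vals
```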
6 changes: 3 additions & 3 deletions examples/extending/algorithm/customiqlalgorithm.yaml
@@ -2,6 +2,6 @@ defaults:
- customalgorithm_config
- _self_

delay_value: bool = True
loss_function: str = "l2"
my_custom_arg: int = 3
delay_value: True
loss_function: "l2"
my_custom_arg: 3
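
For context, the corrected lines above are plain YAML values; the type hints from the old lines belong in the Python dataclass that Hydra validates the YAML against. Below is a minimal sketch of such a dataclass, using only the field names and defaults visible in this file (the class name and its location are assumptions).

```python
# Hypothetical dataclass backing customiqlalgorithm.yaml; field names and
# defaults mirror the YAML above, everything else is illustrative.
from dataclasses import dataclass


@dataclass
class CustomAlgorithmConfig:
    delay_value: bool = True
    loss_function: str = "l2"
    my_custom_arg: int = 3
```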
9 changes: 6 additions & 3 deletions examples/extending/task/README.md
@@ -1,16 +1,19 @@

# Creating a new task (from a new environment)

Here are the steps to create a new algorithm. You can find the custom IQL algorithm
created for this example in [`custom_agorithm.py`](custom_algorithm.py).
Here are the steps to create a new task.

1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py).
This is an enum with task entries and some abstract functions you need to implement.

2. Create a `customenv` folder with a yaml configuration file for each of your tasks
2. Create a `customenv` folder with a yaml configuration file for each of your tasks.
You can see [`customenv`](customenv) for an example.
3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and
your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to
override from).
4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the
[`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks (a registration sketch follows step 5 below).
5. Load it with
```bash
python benchmarl/run.py task=customenv/task_name algorithm=...
```
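
A rough sketch of what step 4 above could look like for an environment with two tasks. The task names `TASK_1`/`TASK_2` and the import paths are made up for illustration; in practice the entries are added to the existing dict in [`benchmarl/environments/__init__.py`](../../../benchmarl/environments/__init__.py).

```python
# Hypothetical registration of a two-task custom environment (step 4).
# TASK_1 / TASK_2 are made-up names; the import paths assume the placement from step 3.
from benchmarl.environments import task_config_registry
from benchmarl.environments.customenv.common import CustomEnvTask

task_config_registry.update(
    {
        "customenv/task_1": CustomEnvTask.TASK_1,  # keys are what task=... resolves on the CLI
        "customenv/task_2": CustomEnvTask.TASK_2,
    }
)
```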
2 changes: 1 addition & 1 deletion examples/extending/task/custom_task.py
@@ -25,7 +25,7 @@ def get_env_fun(
) -> Callable[[], EnvBase]:
return lambda: YourTorchRLEnvConstructor(
scenario=self.name.lower(),
num_envs=num_envs, # Number of vectorized envs (do not use this param if the env is not vecotrized)
num_envs=num_envs, # Number of vectorized envs (do not use this param if the env is not vectorized)
continuous_actions=continuous_actions, # Ignore this param if your env does not have this choice
seed=seed,
device=device,
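
For orientation, a hedged sketch of how the factory returned by `get_env_fun` above might be called. The task member `TASK_1` is made up, and the keyword argument names are inferred from the variables used in the method body, so treat the whole call as an assumption rather than BenchMARL's actual API.

```python
# Hypothetical call site for the zero-argument constructor returned above.
from benchmarl.environments.customenv.common import CustomEnvTask  # assumed path

env_fn = CustomEnvTask.TASK_1.get_env_fun(
    num_envs=4,
    continuous_actions=True,
    seed=0,
    device="cpu",
)
env = env_fn()  # the TorchRL env is only built when the callable is invoked
```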
