diff --git a/examples/extending/algorithm/README.md b/examples/extending/algorithm/README.md index 3cf4323a..4ee6cc92 100644 --- a/examples/extending/algorithm/README.md +++ b/examples/extending/algorithm/README.md @@ -11,8 +11,12 @@ and a associated dataclass to validate loaded configs. in your script. Make sure it has `customalgorithm_config` within its defaults at the top of the file to let hydra know which python dataclass it is associated to. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml) -for an example +for an example. 3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to override from) 4. Add `{"custom_agorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithm.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) +5. Load it with +```bash +python benchmarl/run.py algorithm=customalgorithm task=... +``` diff --git a/examples/extending/algorithm/custom_algorithm.py b/examples/extending/algorithm/custom_algorithm.py index 79ca4bb8..baaeb0de 100644 --- a/examples/extending/algorithm/custom_algorithm.py +++ b/examples/extending/algorithm/custom_algorithm.py @@ -201,6 +201,13 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase: return batch + def process_loss_vals( + self, group: str, loss_vals: TensorDictBase + ) -> TensorDictBase: + # Here you can modify the loss_vals tensordict containing entries loss_name->loss_value + # For example you can sum two entries in a new entry to optimize them together. + return loss_vals + ##################### # Custom new methods ##################### diff --git a/examples/extending/algorithm/customiqlalgorithm.yaml b/examples/extending/algorithm/customiqlalgorithm.yaml index a5a2e12b..e8f3b864 100644 --- a/examples/extending/algorithm/customiqlalgorithm.yaml +++ b/examples/extending/algorithm/customiqlalgorithm.yaml @@ -2,6 +2,6 @@ defaults: - customalgorithm_config - _self_ -delay_value: bool = True -loss_function: str = "l2" -my_custom_arg: int = 3 +delay_value: True +loss_function: "l2" +my_custom_arg: 3 diff --git a/examples/extending/task/README.md b/examples/extending/task/README.md index e4ecd9e0..fa9591d5 100644 --- a/examples/extending/task/README.md +++ b/examples/extending/task/README.md @@ -1,16 +1,19 @@ # Creating a new task (from a new environemnt) -Here are the steps to create a new algorithm. You can find the custom IQL algorithm -created for this example in [`custom_agorithm.py`](custom_algorithm.py). +Here are the steps to create a new task. 1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py). This is an enum with task entries and some abstract functions you need to implement. -2. Create a `customenv` folder with a yaml configuration file for each of your tasks +2. Create a `customenv` folder with a yaml configuration file for each of your tasks. You can see [`customenv`](customenv) for an example. 3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to override from). 4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the [`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks. +5. Load it with +```bash +python benchmarl/run.py task=customenv/task_name algorithm=... +``` diff --git a/examples/extending/task/custom_task.py b/examples/extending/task/custom_task.py index 860dcee3..91e57fd4 100644 --- a/examples/extending/task/custom_task.py +++ b/examples/extending/task/custom_task.py @@ -25,7 +25,7 @@ def get_env_fun( ) -> Callable[[], EnvBase]: return lambda: YourTorchRLEnvConstructor( scenario=self.name.lower(), - num_envs=num_envs, # Number of vectorized envs (do not use this param if the env is not vecotrized) + num_envs=num_envs, # Number of vectorized envs (do not use this param if the env is not vectorized) continuous_actions=continuous_actions, # Ignore this param if your env does not have this choice seed=seed, device=device,