from cfg.env.env to cfg.env.wrapper + update configs.md
belerico committed Sep 18, 2023
1 parent ee1607a commit 78b1edf
Showing 19 changed files with 89 additions and 34 deletions.
53 changes: 46 additions & 7 deletions howto/configs.md
@@ -97,10 +97,10 @@ defaults:
- buffer: default.yaml
- checkpoint: default.yaml
- env: default.yaml
- exp: ???
- fabric: default.yaml
- hydra: default.yaml
- metric: default.yaml
- hydra: default.yaml
- exp: ???

num_threads: 1
total_steps: ???
@@ -124,10 +124,6 @@ cnn_keys:
mlp_keys:
encoder: []
decoder: ${mlp_keys.encoder}

# Buffer
buffer:
memmap: True
```
### Algorithms
@@ -317,7 +313,50 @@ The environment configs can be found under the `sheeprl/configs/env` folders. Sh
* [MineRL (v0.4.4)](https://minerl.readthedocs.io/en/v0.4.4/)
* [MineDojo (v0.1.0)](https://docs.minedojo.org/)

In this way one can easily try out the overall framework with standard RL environments.
In this way one can easily try out the overall framework with standard RL environments. The `default.yaml` config contains all the environment parameters that are (possibly) shared by every environment:

```yaml
id: ???
num_envs: 4
frame_stack: 1
sync_env: False
screen_size: 64
action_repeat: 1
grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
max_episode_steps: null
reward_as_observation: False
```
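Hydra composes the `defaults` list first and then applies the keys of the concrete file on top, so a custom env config only has to restate what it changes. A minimal sketch of that override semantics, using plain Python dicts as stand-ins for the real OmegaConf objects:

```python
# Shared defaults, mirroring `sheeprl/configs/env/default.yaml` above.
default_env = {
    "id": "???",
    "num_envs": 4,
    "frame_stack": 1,
    "sync_env": False,
    "screen_size": 64,
    "action_repeat": 1,
    "grayscale": False,
    "clip_rewards": False,
    "capture_video": True,
    "frame_stack_dilation": 1,
    "max_episode_steps": None,
    "reward_as_observation": False,
}

# Keys an atari-like config declares after `- default` in its defaults list.
atari_overrides = {
    "id": "PongNoFrameskip-v4",
    "action_repeat": 4,
    "max_episode_steps": 27000,
}

# Later entries in the composition win, so the concrete file's keys replace
# the defaults while everything else is inherited unchanged.
atari_env = {**default_env, **atari_overrides}
print(atari_env["action_repeat"])  # 4 (overridden)
print(atari_env["num_envs"])       # 4 (inherited from default)
```

Interpolations such as `${env.id}` are resolved later by OmegaConf and are not modeled in this sketch.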

Every custom environment must then "inherit" from this default config, override the relevant parameters, and define the `wrapper` field, which is the one directly instantiated at runtime. The `wrapper` field must define all the specific parameters to be passed to the `_target_` function when the wrapper is instantiated. Take for example the `atari.yaml` config:

```yaml
defaults:
- default
- _self_
# Override from `default` config
action_repeat: 4
id: PongNoFrameskip-v4
max_episode_steps: 27000

# Wrapper to be instantiated
wrapper:
_target_: gymnasium.wrappers.AtariPreprocessing # https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.AtariPreprocessing
env:
_target_: gymnasium.make
id: ${env.id}
render_mode: rgb_array
noop_max: 30
terminal_on_life_loss: False
frame_skip: ${env.action_repeat}
screen_size: ${env.screen_size}
grayscale_obs: ${env.grayscale}
scale_obs: False
grayscale_newaxis: True
```
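At runtime the `wrapper` node is handed to `hydra.utils.instantiate`, which resolves `_target_`, recursively instantiates nested nodes (here the inner `env`), and calls the target with the remaining keys as keyword arguments. A toy sketch of that mechanism, with a local registry and stand-in classes instead of Hydra's dotted-path imports and the real gymnasium objects:

```python
# Local registry standing in for Hydra's dotted-path import of `_target_`.
REGISTRY = {}

def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls

@register
class Make:  # stand-in for `gymnasium.make`
    def __init__(self, id, render_mode=None):
        self.id, self.render_mode = id, render_mode

@register
class AtariPreprocessing:  # stand-in for the real gymnasium wrapper
    def __init__(self, env, noop_max=30, frame_skip=4, screen_size=84,
                 grayscale_obs=True, scale_obs=False):
        self.env, self.frame_skip, self.screen_size = env, frame_skip, screen_size

def instantiate(node):
    # Recursively build nested `_target_` nodes, then call the target with
    # the remaining keys as kwargs -- the core of `hydra.utils.instantiate`.
    if isinstance(node, dict) and "_target_" in node:
        kwargs = {k: instantiate(v) for k, v in node.items() if k != "_target_"}
        return REGISTRY[node["_target_"]](**kwargs)
    return node

# Dict form of the `wrapper` node above, with interpolations already resolved.
wrapper_cfg = {
    "_target_": "AtariPreprocessing",
    "env": {"_target_": "Make", "id": "PongNoFrameskip-v4", "render_mode": "rgb_array"},
    "frame_skip": 4,    # would come from ${env.action_repeat}
    "screen_size": 64,  # would come from ${env.screen_size}
}

env = instantiate(wrapper_cfg)
print(type(env.env).__name__)  # Make
```

The nested `env` node is instantiated first and passed as the `env` kwarg, which is exactly why the inner `gymnasium.make` config can live inside the wrapper config.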
> **Warning**
>
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo.py
@@ -105,7 +105,7 @@ def train(

@register_algorithm()
def main(fabric: Fabric, cfg: DictConfig):
if "minedojo" in cfg.env.env._target_.lower():
if "minedojo" in cfg.env.wrapper._target_.lower():
raise ValueError(
"MineDojo is not currently supported by PPO agent, since it does not take "
"into consideration the action masks provided by the environment, but needed "
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo_decoupled.py
@@ -544,7 +544,7 @@ def main(fabric: Fabric, cfg: DictConfig):
"`python sheeprl.py exp=ppo_decoupled fabric.devices=2 ...`"
)

if "minedojo" in cfg.env.env._target_.lower():
if "minedojo" in cfg.env.wrapper._target_.lower():
raise ValueError(
"MineDojo is not currently supported by PPO agent, since it does not take "
"into consideration the action masks provided by the environment, but needed "
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
@@ -446,7 +446,7 @@ def main(fabric: Fabric, cfg: DictConfig):
"`python sheeprl.py exp=sac_decoupled fabric.devices=2 ...`"
)

if "minedojo" in cfg.env.env._target_.lower():
if "minedojo" in cfg.env.wrapper._target_.lower():
raise ValueError(
"MineDojo is not currently supported by PPO agent, since it does not take "
"into consideration the action masks provided by the environment, but needed "
2 changes: 1 addition & 1 deletion sheeprl/algos/sac_ae/sac_ae.py
@@ -128,7 +128,7 @@ def train(

@register_algorithm()
def main(fabric: Fabric, cfg: DictConfig):
if "minedojo" in cfg.env.env._target_.lower():
if "minedojo" in cfg.env.wrapper._target_.lower():
raise ValueError(
"MineDojo is not currently supported by SAC-AE agent, since it does not take "
"into consideration the action masks provided by the environment, but needed "
9 changes: 6 additions & 3 deletions sheeprl/configs/env/atari.yaml
@@ -2,11 +2,14 @@ defaults:
- default
- _self_

id: PongNoFrameskip-v4
# Override from `default` config
action_repeat: 4
id: PongNoFrameskip-v4
max_episode_steps: 27000
env:
_target_: gymnasium.wrappers.AtariPreprocessing

# Wrapper to be instantiated
wrapper:
_target_: gymnasium.wrappers.AtariPreprocessing # https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.AtariPreprocessing
env:
_target_: gymnasium.make
id: ${env.id}
3 changes: 2 additions & 1 deletion sheeprl/configs/env/default.yaml
@@ -1,6 +1,7 @@
id: ???
num_envs: 4
sync_env: False
frame_stack: 1
sync_env: False
screen_size: 64
action_repeat: 1
grayscale: False
6 changes: 4 additions & 2 deletions sheeprl/configs/env/diambra.yaml
@@ -2,12 +2,14 @@ defaults:
- default
- _self_

# Override from `default` config
id: doapp
action_repeat: 1
frame_stack: 4
sync_env: True
action_repeat: 1

env:
# Wrapper to be instantiated
wrapper:
_target_: sheeprl.envs.diambra.DiambraWrapper
id: ${env.id}
action_space: discrete
4 changes: 3 additions & 1 deletion sheeprl/configs/env/dmc.yaml
@@ -2,11 +2,13 @@ defaults:
- default
- _self_

# Override from `default` config
id: walker_walk
action_repeat: 1
max_episode_steps: 1000

env:
# Wrapper to be instantiated
wrapper:
_target_: sheeprl.envs.dmc.DMCWrapper
id: ${env.id}
width: ${env.screen_size}
4 changes: 3 additions & 1 deletion sheeprl/configs/env/dummy.yaml
@@ -2,8 +2,10 @@ defaults:
- default
- _self_

# Override from `default` config
id: discrete_dummy

env:
# Wrapper to be instantiated
wrapper:
_target_: sheeprl.utils.env.get_dummy_env
id: ${env.id}
4 changes: 3 additions & 1 deletion sheeprl/configs/env/gym.yaml
@@ -2,10 +2,12 @@ defaults:
- default
- _self_

# Override from `default` config
id: CartPole-v1
mask_velocities: False

env:
# Wrapper to be instantiated
wrapper:
_target_: gymnasium.make
id: ${env.id}
render_mode: rgb_array
6 changes: 3 additions & 3 deletions sheeprl/configs/env/minecraft.yaml
@@ -2,8 +2,8 @@ defaults:
- default
- _self_

min_pitch: -60
max_pitch: 60
break_speed_multiplier: 100
min_pitch: -60
sticky_jump: 10
sticky_attack: 30
sticky_jump: 10
break_speed_multiplier: 100
4 changes: 3 additions & 1 deletion sheeprl/configs/env/minedojo.yaml
@@ -2,10 +2,12 @@ defaults:
- minecraft
- _self_

# Override from `minecraft` config
id: open-ended
action_repeat: 1

env:
# Wrapper to be instantiated
wrapper:
_target_: sheeprl.envs.minedojo.MineDojoWrapper
id: ${env.id}
height: ${env.screen_size}
4 changes: 3 additions & 1 deletion sheeprl/configs/env/minerl.yaml
@@ -2,10 +2,12 @@ defaults:
- minecraft
- _self_

# Override from `minecraft` config
id: custom_navigate
action_repeat: 1

env:
# Wrapper to be instantiated
wrapper:
_target_: sheeprl.envs.minerl.MineRLWrapper
id: ${env.id}
height: ${env.screen_size}
2 changes: 1 addition & 1 deletion sheeprl/configs/exp/dreamer_v3_L_doapp.yaml
@@ -14,7 +14,7 @@ env:
id: doapp
num_envs: 8
frame_stack: 1
env:
wrapper:
diambra_settings:
characters: Kasumi

@@ -18,7 +18,7 @@ env:
frame_stack: 1
screen_size: 128
reward_as_observation: True
env:
wrapper:
attack_but_combination: True
diambra_settings:
characters: Kasumi
@@ -81,4 +81,4 @@ algo:

# Metric
metric:
log_every: 10000
log_every: 10000
2 changes: 1 addition & 1 deletion sheeprl/configs/exp/dreamer_v3_L_navigate.yaml
@@ -14,7 +14,7 @@ env:
num_envs: 4
id: custom_navigate
reward_as_observation: True
env:
wrapper:
multihot_inventory: False

# Checkpoint
4 changes: 2 additions & 2 deletions sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
@@ -20,7 +20,7 @@ env:
num_envs: 1
max_episode_steps: 1000
id: walker_walk
env:
wrapper:
from_vectors: True
from_pixels: True

@@ -52,4 +52,4 @@ algo:

# Metric
metric:
log_every: 5000
log_every: 5000
6 changes: 3 additions & 3 deletions sheeprl/utils/env.py
@@ -81,11 +81,11 @@ def thunk() -> gym.Env:
env_spec = ""

instantiate_kwargs = {}
if "seed" in cfg.env.env:
if "seed" in cfg.env.wrapper:
instantiate_kwargs["seed"] = seed
if "rank" in cfg.env.env:
if "rank" in cfg.env.wrapper:
instantiate_kwargs["rank"] = rank + vector_env_idx
env = hydra.utils.instantiate(cfg.env.env, **instantiate_kwargs)
env = hydra.utils.instantiate(cfg.env.wrapper, **instantiate_kwargs)

# action repeat
if (
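The hunk above forwards `seed` and `rank` to the wrapper's `_target_` only when the wrapper config actually declares those keys, so targets that do not accept these kwargs keep working. A small sketch of that pattern, with plain dicts and a toy `instantiate` standing in for OmegaConf and Hydra:

```python
def instantiate(cfg, **extra):
    # Toy stand-in for `hydra.utils.instantiate`: strip `_target_` and
    # return the merged kwargs instead of calling the real target.
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    kwargs.update(extra)
    return kwargs

def make_env_kwargs(wrapper_cfg, seed, rank, vector_env_idx):
    # Mirror of the logic in the hunk: inject only the keys the config declares.
    instantiate_kwargs = {}
    if "seed" in wrapper_cfg:
        instantiate_kwargs["seed"] = seed
    if "rank" in wrapper_cfg:
        instantiate_kwargs["rank"] = rank + vector_env_idx
    return instantiate(wrapper_cfg, **instantiate_kwargs)

# A wrapper that declares `seed` gets it injected; one that doesn't is untouched.
with_seed = make_env_kwargs({"_target_": "W", "seed": None}, seed=42, rank=1, vector_env_idx=2)
without = make_env_kwargs({"_target_": "W"}, seed=42, rank=1, vector_env_idx=2)
print(with_seed["seed"], "seed" in without)  # 42 False
```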
