From 419c7ce05b67b0fd89b62ae0b73b71b3f7a96514 Mon Sep 17 00:00:00 2001 From: Michele Milesi <74559684+michele-milesi@users.noreply.github.com> Date: Mon, 13 May 2024 13:45:04 +0200 Subject: [PATCH] Fix/minedojo (#286) * fix: multiple envs * fix: multi-discrete actions * fix: remove debug prints * fix: remove debug prints * fix: removed MINEDOJO_HEADLESS --- howto/register_external_algorithm.md | 4 +- howto/register_new_algorithm.md | 4 +- notebooks/dreamer_v3_imagination.ipynb | 4 +- sheeprl/algos/a2c/a2c.py | 4 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 2 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 15 ++++-- sheeprl/algos/dreamer_v2/utils.py | 4 +- sheeprl/algos/dreamer_v3/agent.py | 2 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 9 ++-- sheeprl/algos/dreamer_v3/utils.py | 4 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 4 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 10 ++-- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 4 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 4 +- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 4 +- sheeprl/algos/ppo/ppo.py | 4 +- sheeprl/algos/ppo/ppo_decoupled.py | 4 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 4 +- sheeprl/algos/ppo_recurrent/utils.py | 4 +- sheeprl/algos/sac/sac_decoupled.py | 2 +- sheeprl/configs/exp/dreamer_v3_minedojo.yaml | 57 ++++++++++++++++++++ sheeprl/envs/minedojo.py | 6 ++- tests/test_algos/test_algos.py | 2 +- 24 files changed, 119 insertions(+), 44 deletions(-) create mode 100644 sheeprl/configs/exp/dreamer_v3_minedojo.yaml diff --git a/howto/register_external_algorithm.md b/howto/register_external_algorithm.md index d1062dab..d86bc36c 100644 --- a/howto/register_external_algorithm.md +++ b/howto/register_external_algorithm.md @@ -605,9 +605,9 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): } actions = agent.module(torch_obs) if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() + real_actions = torch.stack(actions, -1).cpu().numpy() else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() # Single environment step diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 04d3c09d..5d6f6b48 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -603,9 +603,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): } actions = player.get_actions(torch_obs) if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() + real_actions = torch.stack(actions, -1).cpu().numpy() else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() # Single environment step diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index 56b931e7..e03e5b04 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -222,9 +222,9 @@ " real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)\n", " actions = torch.cat(actions, -1).cpu().numpy()\n", " if is_continuous:\n", - " real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n", + " real_actions = torch.stack(real_actions, dim=-1).cpu().numpy()\n", " else:\n", - " real_actions = 
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()\n", + " real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()\n", "\n", " step_data[\"actions\"] = actions.reshape((1, cfg.env.num_envs, -1))\n", " rb_initial.add(step_data, validate_args=cfg.buffer.validate_args)\n", diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index 2002fbe1..e3b23ee1 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -239,9 +239,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs) actions, _, values = player(torch_obs) if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() + real_actions = torch.stack(actions, -1).cpu().numpy() else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy() + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() # Single environment step diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index c9fdfbf5..1d69b223 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -581,7 +581,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index f9e251ab..d819a66f 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -145,7 +145,11 @@ def train( # One step of dynamic learning, which take the posterior state, the recurrent state, the action # and the observation and compute the next recurrent, prior and posterior states recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic( - posterior, recurrent_state, data["actions"][i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1] + posterior, + recurrent_state, + data["actions"][i : i + 1], + embedded_obs[i : i + 1], + data["is_first"][i : i + 1], ) recurrent_states[i] = recurrent_state priors_logits[i] = prior_logits @@ -344,7 +348,10 @@ def train( critic_grads = None if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0: critic_grads = fabric.clip_gradients( - module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False + module=critic, + optimizer=critic_optimizer, + max_norm=cfg.algo.critic.clip_gradients, + error_if_nonfinite=False, ) critic_optimizer.step() @@ -606,10 +613,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in 
real_actions], dim=-1).cpu().numpy() ) step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index 4abfa58f..3a846858 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -150,9 +150,9 @@ def test( torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: - real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() # Single environment step obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 1650e2ea..02a40a8a 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -913,7 +913,7 @@ def forward( if sampled_action == 15: # Craft action logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf elif i == 2: - mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits) + mask["mask_destroy"] = mask["mask_destroy"].expand_as(logits) mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) for t in range(functional_action.shape[0]): for b in range(functional_action.shape[1]): diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 51b48351..b555f765 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -319,7 +319,10 @@ def train( critic_grads = None if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0: critic_grads = fabric.clip_gradients( - module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False + module=critic, + optimizer=critic_optimizer, + max_norm=cfg.algo.critic.clip_gradients, + error_if_nonfinite=False, ) critic_optimizer.step() @@ -573,10 +576,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() + real_actions = torch.stack(real_actions, dim=-1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index 1b73e60c..2fdac419 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -125,9 +125,9 @@ def test( torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: - real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + real_actions = torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() # Single environment step obs, reward, done, 
truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 3f66a606..cf2e8b2f 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -605,7 +605,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 7767a9bb..ba8bd4ac 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -260,10 +260,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) next_obs, rewards, terminated, truncated, infos = envs.step( real_actions.reshape(envs.action_space.shape) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 72b2fb3b..7bb483cf 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -128,7 +128,11 @@ def train( for i in range(0, sequence_length): recurrent_state, posterior, prior, posterior_logits, prior_logits = world_model.rssm.dynamic( - posterior, recurrent_state, data["actions"][i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1] + posterior, + recurrent_state, + data["actions"][i : i + 1], + embedded_obs[i : i + 1], + data["is_first"][i : i + 1], ) recurrent_states[i] = recurrent_state priors[i] = prior @@ -742,10 +746,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index f5203987..1308d8f4 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -280,10 +280,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, 
-1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() + real_actions = torch.stack(real_actions, -1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 9339bae9..9ab79630 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -814,10 +814,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() + real_actions = torch.stack(real_actions, dim=-1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 7370db43..a29ebce6 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -262,10 +262,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: - real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() + real_actions = torch.stack(real_actions, dim=-1).cpu().numpy() else: real_actions = ( - torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + torch.stack([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 8dfe8812..dd83732b 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -276,9 +276,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) actions, logprobs, values = player(torch_obs) if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() + real_actions = torch.stack(actions, -1).cpu().numpy() else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() # Single environment step diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index d41f7911..79f97337 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -206,9 +206,9 @@ def player( torch_obs = prepare_obs(fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) actions, logprobs, values = agent(torch_obs) if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() + real_actions = torch.stack(actions, -1).cpu().numpy() else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], 
dim=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() # Single environment step diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 8b84d128..bfd5d52a 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -299,9 +299,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states ) if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() + real_actions = torch.stack(actions, -1).cpu().numpy() else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() torch_actions = torch.cat(actions, dim=-1) actions = torch_actions.cpu().numpy() diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 90ccdaed..47111ade 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -57,10 +57,10 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_d # Act greedly through the environment actions, state = agent.get_actions(torch_obs, actions, state, greedy=True) if agent.actor.is_continuous: - real_actions = torch.cat(actions, -1) + real_actions = torch.stack(actions, -1) actions = torch.cat(actions, dim=-1).view(1, 1, -1) else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) + real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1) actions = torch.cat([act for act in actions], dim=-1).view(1, 1, -1) # Single environment step diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 6350bbec..44074f5a 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -189,7 +189,7 @@ def player( torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) actions = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: diff --git a/sheeprl/configs/exp/dreamer_v3_minedojo.yaml b/sheeprl/configs/exp/dreamer_v3_minedojo.yaml new file mode 100644 index 00000000..23591e49 --- /dev/null +++ b/sheeprl/configs/exp/dreamer_v3_minedojo.yaml @@ -0,0 +1,57 @@ +# @package _global_ + +defaults: + - dreamer_v3 + - override /algo: dreamer_v3_XS + - override /env: minedojo + - _self_ + +# Experiment +seed: 5 +total_steps: 50000000 + +# Environment +env: + num_envs: 2 + id: harvest_milk + reward_as_observation: True + +# Checkpoint +checkpoint: + every: 100000 + +# Buffer +buffer: + checkpoint: True + +# Algorithm +algo: + replay_ratio: 0.015625 + learning_starts: 65536 + actor: + cls: sheeprl.algos.dreamer_v3.agent.MinedojoActor + cnn_keys: + encoder: + - rgb + mlp_keys: + encoder: + - equipment + - inventory + - inventory_delta + - inventory_max + - life_stats + - mask_action_type + - mask_craft_smelt + - mask_destroy + - mask_equip_place + - reward + decoder: + - equipment + - inventory + - inventory_delta + - inventory_max + - life_stats + - mask_action_type + - mask_craft_smelt + - mask_destroy + - mask_equip_place diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index 17cf7b0d..57e40703 100644 --- 
a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -10,6 +10,7 @@ import gymnasium as gym import minedojo +import minedojo.tasks import numpy as np from gymnasium.core import RenderFrame from minedojo.sim import ALL_CRAFT_SMELT_ITEMS, ALL_ITEMS @@ -39,6 +40,7 @@ } ITEM_ID_TO_NAME = dict(enumerate(ALL_ITEMS)) ITEM_NAME_TO_ID = dict(zip(ALL_ITEMS, range(N_ALL_ITEMS))) +ALL_TASKS_SPECS = copy.deepcopy(minedojo.tasks.ALL_TASKS_SPECS) # Minedojo functional actions: # 0: noop @@ -67,7 +69,7 @@ def __init__( self._width = width self._pitch_limits = pitch_limits self._pos = kwargs.get("start_position", None) - self._break_speed_multiplier = kwargs.get("break_speed_multiplier", 100) + self._break_speed_multiplier = kwargs.pop("break_speed_multiplier", 100) self._start_pos = copy.deepcopy(self._pos) self._sticky_attack = 0 if self._break_speed_multiplier > 1 else sticky_attack self._sticky_jump = sticky_jump @@ -84,6 +86,7 @@ def __init__( image_size=(height, width), world_seed=seed, fast_reset=True, + break_speed_multiplier=self._break_speed_multiplier, **kwargs, ) super().__init__(env) @@ -109,6 +112,7 @@ def __init__( ) self._render_mode: str = "rgb_array" self.seed(seed=seed) + minedojo.tasks.ALL_TASKS_SPECS = copy.deepcopy(ALL_TASKS_SPECS) @property def render_mode(self) -> str | None: diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 511a4691..09e98221 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -26,7 +26,7 @@ def standard_args(): "hydra/hydra_logging=disabled", "dry_run=True", "checkpoint.save_last=False", - "env.num_envs=1", + "env.num_envs=2", f"env.sync_env={_IS_WINDOWS}", "env.capture_video=False", "fabric.devices=auto",
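
The recurring torch.cat -> torch.stack change is the core of the multi-discrete action fix: after argmax(dim=-1) each per-branch action tensor is one-dimensional over the environments, so concatenating them along -1 interleaves branches across environments, while stacking adds a proper trailing branch dimension. A minimal sketch of the shape difference, with hypothetical sizes (nothing below is taken verbatim from the repository):

import torch

num_envs, branch_sizes = 2, [3, 4]  # a made-up multi-discrete space with two action branches

# One logits tensor per branch, shaped (num_envs, branch_size), as the players produce them
actions = [torch.randn(num_envs, n) for n in branch_sizes]
argmaxed = [act.argmax(dim=-1) for act in actions]  # each has shape (num_envs,)

cat_result = torch.cat(argmaxed, dim=-1)      # shape (4,): env axis and branch axis merged
stack_result = torch.stack(argmaxed, dim=-1)  # shape (2, 2): one row of branch indices per env

print(cat_result.shape, stack_result.shape)

With a single environment both variants reshape to the same values, which is consistent with the accompanying test change from env.num_envs=1 to env.num_envs=2: the bug only surfaces once real_actions.reshape(envs.action_space.shape) has more than one environment row to fill.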
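
In sheeprl/envs/minedojo.py, switching break_speed_multiplier from kwargs.get to kwargs.pop and forwarding it explicitly to minedojo.make means the wrapper's default of 100 always reaches the underlying simulator (previously it was only forwarded when the caller happened to supply the key), and it avoids passing the same keyword twice now that it is named explicitly. A rough illustration of the get/pop difference with a stand-in factory (the real MineDojo default and the exact failure mode are assumptions inferred from the diff):

def make(**kwargs):
    # Stand-in for minedojo.make(): pretend the simulator's own default is 1.
    return {"break_speed_multiplier": kwargs.get("break_speed_multiplier", 1)}

def wrapper_with_get(**kwargs):
    multiplier = kwargs.get("break_speed_multiplier", 100)  # key stays in kwargs (or is absent)
    return multiplier, make(**kwargs)                       # default of 100 never reaches make()

def wrapper_with_pop(**kwargs):
    multiplier = kwargs.pop("break_speed_multiplier", 100)  # key removed from kwargs
    return multiplier, make(break_speed_multiplier=multiplier, **kwargs)  # always forwarded, exactly once

print(wrapper_with_get())  # (100, {'break_speed_multiplier': 1})   -> wrapper and simulator disagree
print(wrapper_with_pop())  # (100, {'break_speed_multiplier': 100}) -> consistent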
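
The new minedojo.tasks import, the module-level ALL_TASKS_SPECS = copy.deepcopy(...) snapshot, and the restore after environment construction follow a snapshot-and-restore pattern: building one task appears to mutate MineDojo's module-level task-spec table, which would break creating further environments in the same process (hence the "fix: multiple envs" commit). How MineDojo actually mutates the table is an assumption; the toy below only demonstrates the pattern with made-up names:

import copy

TASK_SPECS = {"harvest_milk": {"reward": 1}, "hunt_cow": {"reward": 1}}  # stands in for minedojo.tasks.ALL_TASKS_SPECS
_SPECS_SNAPSHOT = copy.deepcopy(TASK_SPECS)  # taken once, before any environment is built

def make_env(task_id: str) -> dict:
    global TASK_SPECS
    spec = TASK_SPECS.pop(task_id)               # simulate a factory that consumes the spec table
    env = {"task_id": task_id, **spec}
    TASK_SPECS = copy.deepcopy(_SPECS_SNAPSHOT)  # restore so the next environment sees a clean table
    return env

print(make_env("harvest_milk"))
print(make_env("harvest_milk"))  # succeeds a second time only because the table was restored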