
Commit

Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature/resume_from_checkpoint
michele-milesi committed Sep 18, 2023
2 parents dbc1345 + d6aa9ea commit 321c955
Showing 6 changed files with 46 additions and 52 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -28,7 +28,7 @@ repos:
        files: \.py$

  - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.9.1
    hooks:
      - id: black
        name: (black) Format Python code
@@ -43,7 +43,7 @@ repos:
        types: [jupyter]

  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: "v0.0.287"
+    rev: "v0.0.288"
    hooks:
      - id: ruff
        args: ["--config", "pyproject.toml", "--fix", "./sheeprl"]
2 changes: 1 addition & 1 deletion sheeprl/configs/env/diambra.yaml
@@ -8,7 +8,7 @@ frame_stack: 4
sync_env: True

env:
-  _target_: sheeprl.envs.diambra_wrapper.DiambraWrapper
+  _target_: sheeprl.envs.diambra.DiambraWrapper
  id: ${env.id}
  action_space: discrete
  screen_size: ${env.screen_size}
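
Note: the `_target_` key is the dotted import path that Hydra resolves when instantiating the environment, so this one-line change simply points at the renamed module. A minimal sketch of the mechanism, assuming hydra-core is installed, with hard-coded stand-ins for the `${env.*}` interpolations (the values shown are illustrative, not from the repo):

    import hydra

    # Hard-coded stand-ins for the interpolated values in the config above.
    cfg = {
        "_target_": "sheeprl.envs.diambra.DiambraWrapper",
        "id": "doapp",  # illustrative game id
        "action_space": "discrete",
        "screen_size": 64,
    }

    # Imports sheeprl.envs.diambra and calls DiambraWrapper(id=..., ...).
    env = hydra.utils.instantiate(cfg)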
3 changes: 1 addition & 2 deletions sheeprl/configs/env/dmc.yaml
@@ -9,9 +9,8 @@ max_episode_steps: 1000
env:
  _target_: sheeprl.envs.dmc.DMCWrapper
  id: ${env.id}
-  height: ${env.screen_size}
  width: ${env.screen_size}
-  frame_skip: ${env.action_repeat}
+  height: ${env.screen_size}
  seed: null
  from_pixels: True
  from_vectors: False
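
Note that `frame_skip: ${env.action_repeat}` is dropped rather than moved: action repeat for DMC is now handled by the generic `ActionRepeat` wrapper applied in `sheeprl/utils/env.py` (see that diff below). A hedged sketch of what instantiating this config amounts to, assuming the wrapper accepts the remaining keys as keyword arguments:

    from sheeprl.envs.dmc import DMCWrapper

    # Keyword names mirror the new dmc.yaml keys; the id format is assumed.
    env = DMCWrapper(
        id="walker_walk",  # hypothetical "<domain>_<task>" id
        width=64,
        height=64,
        seed=None,
        from_pixels=True,
        from_vectors=False,
    )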
6 changes: 5 additions & 1 deletion sheeprl/envs/diambra_wrapper.py → sheeprl/envs/diambra.py
@@ -105,7 +105,7 @@ def _convert_obs(self, obs: Dict[str, Union[int, np.ndarray]]) -> Dict[str, np.ndarray]:
    def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
        obs, reward, done, infos = self._env.step(action)
        infos["env_domain"] = "DIAMBRA"
-        return self._convert_obs(obs), reward, done, False, infos
+        return self._convert_obs(obs), reward, done or infos.get("env_done", False), False, infos

    def render(self, mode: str = "rgb_array", **kwargs) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
        return self._env.render("rgb_array")
@@ -114,3 +114,7 @@ def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[Any, Dict[str, Any]]:
        return self._convert_obs(self._env.reset()), {"env_domain": "DIAMBRA"}
+
+    def close(self) -> None:
+        self._env.close()
+        super().close()
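
Two behavioral changes land in this file: `step` now also terminates when the underlying environment reports `env_done` in its info dict, and the wrapper gains an explicit `close` that shuts down the DIAMBRA engine before running gymnasium's own cleanup. A hedged usage sketch (the constructor arguments are illustrative, not the wrapper's exact signature):

    from sheeprl.envs.diambra import DiambraWrapper

    env = DiambraWrapper(id="doapp", action_space="discrete", screen_size=64)  # hypothetical args
    obs, info = env.reset()  # info["env_domain"] == "DIAMBRA"
    obs, reward, done, truncated, info = env.step(env.action_space.sample())
    env.close()  # closes the DIAMBRA engine, then super().close()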
49 changes: 23 additions & 26 deletions sheeprl/envs/dmc.py
@@ -5,15 +5,15 @@
if not _IS_DMC_AVAILABLE:
    raise ModuleNotFoundError(_IS_DMC_AVAILABLE)

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union

import numpy as np
from dm_control import suite
from dm_env import specs
from gymnasium import core, spaces


-def _spec_to_box(spec, dtype) -> spaces.Space:
+def _spec_to_box(spec, dtype) -> spaces.Box:
    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        dim = int(np.prod(s.shape))
@@ -54,7 +54,6 @@ def __init__(
        height: int = 84,
        width: int = 84,
        camera_id: int = 0,
-        frame_skip: int = 1,
        task_kwargs: Optional[Dict[Any, Any]] = None,
        environment_kwargs: Optional[Dict[Any, Any]] = None,
        channels_first: bool = True,
@@ -93,9 +92,6 @@ def __init__(
                Defaults to 84.
            camera_id (int, optional): the id of the camera from where to take the image observation.
                Defaults to 0.
-            frame_skip (int, optional): action repeat value. Given an action, `frame_skip` steps will be performed
-                in the environment with the given action.
-                Defaults to 1.
            task_kwargs (Optional[Dict[Any, Any]], optional): Optional dict of keyword arguments for the task.
                Defaults to None.
            environment_kwargs (Optional[Dict[Any, Any]], optional): Optional dict specifying
@@ -121,7 +117,6 @@ def __init__(
        self._height = height
        self._width = width
        self._camera_id = camera_id
-        self._frame_skip = frame_skip
        self._channels_first = channels_first

        # create task
@@ -137,6 +132,10 @@ def __init__(
        self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
        self._norm_action_space = spaces.Box(low=-1.0, high=1.0, shape=self._true_action_space.shape, dtype=np.float32)

+        # set the reward range
+        reward_space = _spec_to_box([self._env.reward_spec()], np.float32)
+        self._reward_range = (reward_space.low.item(), reward_space.high.item())
+
        # create observation space
        if from_pixels:
            shape = (3, height, width) if channels_first else (height, width, 3)
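
The reward range is now derived from the environment's reward spec instead of from `frame_skip`. A minimal sketch of the computation with a hand-built dm_env spec (the bounds are illustrative, not taken from a real task):

    import numpy as np
    from dm_env import specs

    # A scalar bounded reward spec of the kind _spec_to_box consumes.
    reward_spec = specs.BoundedArray(shape=(), dtype=np.float64, minimum=0.0, maximum=1.0)
    low, high = np.float32(reward_spec.minimum), np.float32(reward_spec.maximum)
    print((low.item(), high.item()))  # (0.0, 1.0) -> what self._reward_range stores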
@@ -163,7 +162,7 @@ def __init__(
    def __getattr__(self, name):
        return getattr(self._env, name)

-    def _get_obs(self, time_step):
+    def _get_obs(self, time_step) -> Union[Dict[str, np.ndarray], np.ndarray]:
        if self._from_pixels:
            rgb_obs = self.render(camera_id=self._camera_id)
            if self._channels_first:
@@ -177,7 +176,7 @@ def _get_obs(self, time_step)
        else:
            return rgb_obs

-    def _convert_action(self, action):
+    def _convert_action(self, action) -> np.ndarray:
        action = action.astype(np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
@@ -187,20 +186,20 @@ def _convert_action(self, action)
        return action

    @property
-    def observation_space(self):
+    def observation_space(self) -> Union[spaces.Dict, spaces.Box]:
        return self._observation_space

    @property
-    def state_space(self):
+    def state_space(self) -> spaces.Box:
        return self._state_space

    @property
-    def action_space(self):
+    def action_space(self) -> spaces.Box:
        return self._norm_action_space

    @property
-    def reward_range(self):
-        return 0, self._frame_skip
+    def reward_range(self) -> Tuple[float, float]:
+        return self._reward_range

    @property
    def render_mode(self) -> str:
@@ -211,33 +210,31 @@ def seed(self, seed: Optional[int] = None):
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

-    def step(self, action):
+    def step(
+        self, action: Any
+    ) -> Tuple[Union[Dict[str, np.ndarray], np.ndarray], SupportsFloat, bool, bool, Dict[str, Any]]:
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
-        reward = 0
-        extra = {"internal_state": self._env.physics.get_state().copy()}
-
-        for _ in range(self._frame_skip):
-            time_step = self._env.step(action)
-            reward += time_step.reward or 0
-            done = time_step.last()
-            if done:
-                break
+        time_step = self._env.step(action)
+        reward = time_step.reward or 0.0
+        done = time_step.last()
        obs = self._get_obs(time_step)
        self.current_state = _flatten_obs(time_step.observation)
+        extra = {}
        extra["discount"] = time_step.discount
+        extra["internal_state"] = self._env.physics.get_state().copy()
        return obs, reward, done, False, extra

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
-    ) -> Tuple[np.ndarray, Dict[str, Any]]:
+    ) -> Tuple[Union[Dict[str, np.ndarray], np.ndarray], Dict[str, Any]]:
        time_step = self._env.reset()
        self.current_state = _flatten_obs(time_step.observation)
        obs = self._get_obs(time_step)
        return obs, {}

-    def render(self, camera_id: Optional[int] = None):
+    def render(self, camera_id: Optional[int] = None) -> np.ndarray:
        return self._env.physics.render(height=self._height, width=self._width, camera_id=camera_id or self._camera_id)

    def close(self):
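
With the frame-skip loop gone, `step` performs exactly one physics transition per call, and `reward_range` reflects the spec-derived bounds computed in `__init__`. A hedged usage sketch (the `id` format is assumed):

    from sheeprl.envs.dmc import DMCWrapper

    env = DMCWrapper(id="walker_walk", from_pixels=True)  # hypothetical id
    obs, info = env.reset()
    obs, reward, done, truncated, extra = env.step(env.action_space.sample())
    print(sorted(extra))     # ['discount', 'internal_state']
    print(env.reward_range)  # spec-derived bounds, e.g. (0.0, 1.0)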
34 changes: 14 additions & 20 deletions sheeprl/utils/env.py
@@ -12,9 +12,9 @@
from sheeprl.utils.imports import _IS_DIAMBRA_ARENA_AVAILABLE, _IS_DIAMBRA_AVAILABLE, _IS_DMC_AVAILABLE

if _IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE:
-    from sheeprl.envs.diambra_wrapper import DiambraWrapper
+    from sheeprl.envs.diambra import DiambraWrapper
if _IS_DMC_AVAILABLE:
-    from sheeprl.envs.dmc import DMCWrapper
+    pass


def make_env(
@@ -86,14 +86,11 @@ def thunk() -> gym.Env:
            if "rank" in cfg.env.env:
                instantiate_kwargs["rank"] = rank + vector_env_idx
            env = hydra.utils.instantiate(cfg.env.env, **instantiate_kwargs)
-            if "mujoco" in env_spec:
-                env.frame_skip = 0

            # action repeat
            if (
                cfg.env.action_repeat > 1
                and "atari" not in env_spec
-                and (not _IS_DMC_AVAILABLE or not isinstance(env, DMCWrapper))
                and (not (_IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE) or not isinstance(env, DiambraWrapper))
            ):
                env = ActionRepeat(env, cfg.env.action_repeat)
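
Because `DMCWrapper` no longer repeats actions internally, the DMC exclusion is removed here and DMC environments now go through the same `ActionRepeat` wrapper as every other non-Atari env. A sketch of what such a wrapper does, mirroring sheeprl's `ActionRepeat` in spirit (the real implementation may differ in details):

    import gymnasium as gym

    class ActionRepeatSketch(gym.Wrapper):
        # Repeat the same action `amount` times, summing rewards and
        # stopping early if the episode ends (amount >= 1 assumed).
        def __init__(self, env: gym.Env, amount: int):
            super().__init__(env)
            assert amount >= 1
            self._amount = amount

        def step(self, action):
            total_reward = 0.0
            for _ in range(self._amount):
                obs, reward, done, truncated, info = self.env.step(action)
                total_reward += reward
                if done or truncated:
                    break
            return obs, total_reward, done, truncated, info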
@@ -154,41 +151,38 @@ def thunk() -> gym.Env:

        def transform_obs(obs: Dict[str, Any]):
            for k in cnn_keys:
-                shape = obs[k].shape
+                current_obs = obs[k]
+                shape = current_obs.shape
                is_3d = len(shape) == 3
                is_grayscale = not is_3d or shape[0] == 1 or shape[-1] == 1
                channel_first = not is_3d or shape[0] in (1, 3)

                # to 3D image
                if not is_3d:
-                    obs.update({k: np.expand_dims(obs[k], axis=0)})
+                    current_obs = np.expand_dims(current_obs, axis=0)

                # channel last (opencv needs it)
                if channel_first:
-                    obs.update({k: obs[k].transpose(1, 2, 0)})
+                    current_obs = np.transpose(current_obs, (1, 2, 0))

                # resize
-                if obs[k].shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size):
-                    obs.update(
-                        {
-                            k: cv2.resize(
-                                obs[k], (cfg.env.screen_size, cfg.env.screen_size), interpolation=cv2.INTER_AREA
-                            )
-                        }
-                    )
+                if current_obs.shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size):
+                    current_obs = cv2.resize(
+                        current_obs, (cfg.env.screen_size, cfg.env.screen_size), interpolation=cv2.INTER_AREA
+                    )

                # to grayscale
                if cfg.env.grayscale and not is_grayscale:
-                    obs.update({k: cv2.cvtColor(obs[k], cv2.COLOR_RGB2GRAY)})
+                    current_obs = cv2.cvtColor(current_obs, cv2.COLOR_RGB2GRAY)

                # back to 3D
-                if len(obs[k].shape) == 2:
-                    obs.update({k: np.expand_dims(obs[k], axis=-1)})
+                if len(current_obs.shape) == 2:
+                    current_obs = np.expand_dims(current_obs, axis=-1)
                    if not cfg.env.grayscale:
-                        obs.update({k: np.repeat(obs[k], 3, axis=-1)})
+                        current_obs = np.repeat(current_obs, 3, axis=-1)

                # channel first (PyTorch default)
-                obs.update({k: obs[k].transpose(2, 0, 1)})
+                obs[k] = current_obs.transpose(2, 0, 1)

            return obs

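The refactor threads each frame through a local `current_obs` and writes back into `obs[k]` once at the end, instead of mutating the dict at every stage. A self-contained walk-through of the same pipeline on a dummy channel-first RGB frame (sizes are illustrative):

    import cv2
    import numpy as np

    screen_size, grayscale = 64, True
    obs = {"rgb": np.zeros((3, 100, 120), dtype=np.uint8)}  # dummy 3x100x120 frame

    current_obs = obs["rgb"]
    current_obs = np.transpose(current_obs, (1, 2, 0))  # channel last for OpenCV
    current_obs = cv2.resize(current_obs, (screen_size, screen_size), interpolation=cv2.INTER_AREA)
    if grayscale:
        current_obs = cv2.cvtColor(current_obs, cv2.COLOR_RGB2GRAY)
        current_obs = np.expand_dims(current_obs, axis=-1)  # back to 3D
    obs["rgb"] = current_obs.transpose(2, 0, 1)  # channel first (PyTorch default)
    print(obs["rgb"].shape)  # (1, 64, 64)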
