diff --git a/sheeprl/configs/env/diambra.yaml b/sheeprl/configs/env/diambra.yaml index 242d11e9..d13d87ef 100644 --- a/sheeprl/configs/env/diambra.yaml +++ b/sheeprl/configs/env/diambra.yaml @@ -8,7 +8,7 @@ frame_stack: 4 sync_env: True env: - _target_: sheeprl.envs.diambra_wrapper.DiambraWrapper + _target_: sheeprl.envs.diambra.DiambraWrapper id: ${env.id} action_space: discrete screen_size: ${env.screen_size} diff --git a/sheeprl/envs/diambra_wrapper.py b/sheeprl/envs/diambra.py similarity index 96% rename from sheeprl/envs/diambra_wrapper.py rename to sheeprl/envs/diambra.py index e35026a5..f9c2f2ce 100644 --- a/sheeprl/envs/diambra_wrapper.py +++ b/sheeprl/envs/diambra.py @@ -105,7 +105,7 @@ def _convert_obs(self, obs: Dict[str, Union[int, np.ndarray]]) -> Dict[str, np.n def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: obs, reward, done, infos = self._env.step(action) infos["env_domain"] = "DIAMBRA" - return self._convert_obs(obs), reward, done, False, infos + return self._convert_obs(obs), reward, done or infos.get("env_done", False), False, infos def render(self, mode: str = "rgb_array", **kwargs) -> Optional[Union[RenderFrame, List[RenderFrame]]]: return self._env.render("rgb_array") @@ -114,3 +114,7 @@ def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: return self._convert_obs(self._env.reset()), {"env_domain": "DIAMBRA"} + + def close(self) -> None: + self._env.close() + super().close() diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index 8a13cd67..fe25d4ae 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -12,7 +12,7 @@ from sheeprl.utils.imports import _IS_DIAMBRA_ARENA_AVAILABLE, _IS_DIAMBRA_AVAILABLE, _IS_DMC_AVAILABLE if _IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE: - from sheeprl.envs.diambra_wrapper import DiambraWrapper + from sheeprl.envs.diambra import DiambraWrapper if _IS_DMC_AVAILABLE: pass