diff --git a/test/test_libs.py b/test/test_libs.py index 9e941deb477..4046dabfb8e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1567,6 +1567,9 @@ def test_habitat(self, envname): check_env_specs(env) def test_from_config(self): + import habitat + + cfg = habitat.get_config("benchmark/nav/objectnav/objectnav_hssd-hab.yaml") env = HabitatEnv.from_config(cfg) check_env_specs(env) assert isinstance(env, HabitatEnv) diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index d0403d21cd4..3fc9672dabb 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -9,7 +9,7 @@ from torchrl._utils import _make_ordinal_device from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase -from torchrl.envs.libs.gym import GymEnv, set_gym_backend, GymWrapper +from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend from torchrl.envs.utils import _classproperty _has_habitat = importlib.util.find_spec("habitat") is not None