Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support interface of environment customization #310

Merged
merged 23 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,25 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.292
rev: v0.3.5
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 24.3.0
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.15.2
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ sphinx-autoapi
sphinx-autobuild
sphinx-autodoc-typehints
furo
sphinxcontrib-spelling
22 changes: 22 additions & 0 deletions docs/source/envs/custom.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
OmniSafe Customization Interface of Environments
================================================

.. currentmodule:: omnisafe.envs.custom_env

.. autosummary::

CustomEnv

CustomEnv
---------

.. card::
:class-header: sd-bg-success sd-text-white
:class-card: sd-outline-success sd-rounded-1

Documentation
^^^

.. autoclass:: CustomEnv
:members:
:private-members:
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ this project, don't hesitate to ask your question on `the GitHub issue page <htt

start/installation
start/usage
start/env

.. toctree::
:hidden:
Expand Down Expand Up @@ -458,6 +459,7 @@ this project, don't hesitate to ask your question on `the GitHub issue page <htt
:caption: envs api

envs/core
envs/custom
envs/wrapper
envs/safety_gymnasium
envs/mujoco_env
Expand Down
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,4 @@ UpdateActorCritic
UpdateDynamics
mathbb
meger
Jupyter
73 changes: 73 additions & 0 deletions docs/source/start/env.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
Environments Customization
===========================

OmniSafe supports a flexible environment customization interface. Users only need to make minimal
interface adaptations within the simplest template provided by OmniSafe to complete the environment
customization.

.. note::
The highlight of OmniSafe's environment customization is that **users only need to modify the code at the environment layer**, to enjoy OmniSafe's complete set of training, saving, and data logging mechanisms. This allows users who install from PyPI to use it easily and only focus on the dynamics of the environment.


Get Started with the Simplest Template
--------------------------------------

OmniSafe offers a minimal implementation of an environment template as an example of a customized
environments, :doc:`../envs/custom`.
We recommend reading this template in detail and customizing it based on it.

.. card::
:class-header: sd-bg-success sd-text-white
:class-card: sd-outline-success sd-rounded-1

Frequently Asked Questions
^^^
1. What changes are necessary to embed the environment into OmniSafe?
2. My environment requires specific parameters; can these be integrated into OmniSafe's parameter mechanism?
3. I need to log information during training; how can I achieve this?
4. After embedding the environment, how do I run the algorithms in OmniSafe for training?

For the above questions, we provide a complete Jupyter Notebook example (Please see our tutorial on
GitHub page). We will demonstrate how to start from the most common environments in
`Gymnasium <https://gymnasium.farama.org/>`_ style, implement
environment customization and complete the training process.


Customization of Your Environments
-----------------------------------

From Source Code
^^^^^^^^^^^^^^^^

If you are installing from the source code, you can follow the steps below:

.. card::
:class-header: sd-bg-success sd-text-white
:class-card: sd-outline-success sd-rounded-1

Build from Source Code
^^^
1. Create a new file under `omnisafe/envs/`, for example, `omnisafe/envs/my_env.py`.
2. Customize the environment in `omnisafe/envs/my_env.py`. Assuming the class name is `MyEnv`, and the environment name is `MyEnv-v0`.
3. Add `from .my_env import MyEnv` in `omnisafe/envs/__init__.py`.
4. Run the following command in the `omnisafe/examples` folder:

.. code-block:: bash
:linenos:

python train_policy.py --algo PPOLag --env MyEnv-v0

From PyPI
^^^^^^^^^

.. card::
:class-header: sd-bg-success sd-text-white
:class-card: sd-outline-success sd-rounded-1

Build from PyPI
^^^
1. Customize the environment in any folder. Assuming the class name is `MyEnv`, and the environment name is `MyEnv-v0`.
2. Import OmniSafe and the environment registration decorator.
3. Run the training.

For a short but detailed example, please see `examples/train_from_custom_env.py`
90 changes: 90 additions & 0 deletions examples/train_from_custom_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example and template for environment customization."""

from __future__ import annotations

import random
from typing import Any, ClassVar

import torch
from gymnasium import spaces

import omnisafe
from omnisafe.envs.core import CMDP, env_register


# first, define the environment class.
# the most important thing is to add the `env_register` decorator.
@env_register
class CustomExampleEnv(CMDP):

# define what tasks the environment support.
_support_envs: ClassVar[list[str]] = ['Custom-v0']
muchvo marked this conversation as resolved.
Show resolved Hide resolved

# automatically reset when `terminated` or `truncated`
need_auto_reset_wrapper = True
# set `truncated=True` when the total steps exceed the time limit.
need_time_limit_wrapper = True

def __init__(self, env_id: str, **kwargs: dict[str, Any]) -> None:
self._count = 0
self._num_envs = 1
self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))

def step(
self,
action: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:
self._count += 1
obs = torch.as_tensor(self._observation_space.sample())
reward = 2 * torch.as_tensor(random.random()) # noqa
cost = 2 * torch.as_tensor(random.random()) # noqa
terminated = torch.as_tensor(random.random() > 0.9) # noqa
truncated = torch.as_tensor(self._count > self.max_episode_steps)
return obs, reward, cost, terminated, truncated, {'final_observation': obs}

@property
def max_episode_steps(self) -> int:
"""The max steps per episode."""
return 10

def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict]:
self.set_seed(seed)
obs = torch.as_tensor(self._observation_space.sample())
self._count = 0
return obs, {}

def set_seed(self, seed: int) -> None:
random.seed(seed)

def close(self) -> None:
pass

def render(self) -> Any:
pass


# Then you can use it like this:
agent = omnisafe.Agent(
'PPOLag',
'Custom-v0',
)
agent.learn()
6 changes: 3 additions & 3 deletions omnisafe/adapter/offpolicy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def rollout( # pylint: disable=too-many-locals
"""
for _ in range(rollout_step):
if use_rand_action:
act = torch.as_tensor(self._env.sample_action(), dtype=torch.float32).to(
self._device,
)
act = (torch.rand(self.action_space.shape) * 2 - 1).unsqueeze(0) # type: ignore
else:
act = agent.step(self._current_obs, deterministic=False)
next_obs, reward, cost, terminated, truncated, info = self.step(act)
Expand Down Expand Up @@ -181,6 +179,8 @@ def _log_metrics(self, logger: Logger, idx: int) -> None:
logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
idx (int): The index of the environment.
"""
if hasattr(self._env, 'spec_log'):
self._env.spec_log(logger)
logger.store(
{
'Metrics/EpRet': self._ep_ret[idx],
Expand Down
37 changes: 33 additions & 4 deletions omnisafe/adapter/online_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,14 @@ def __init__( # pylint: disable=too-many-arguments
self._cfgs: Config = cfgs
self._device: torch.device = get_device(cfgs.train_cfgs.device)
self._env_id: str = env_id
self._env: CMDP = make(env_id, num_envs=num_envs, device=self._device)
self._eval_env: CMDP = make(env_id, num_envs=1, device=self._device)

env_cfgs = {}

if hasattr(self._cfgs, 'env_cfgs') and self._cfgs.env_cfgs is not None:
env_cfgs = self._cfgs.env_cfgs.todict()

self._env: CMDP = make(env_id, num_envs=num_envs, device=self._device, **env_cfgs)
self._eval_env: CMDP = make(env_id, num_envs=1, device=self._device, **env_cfgs)

self._wrapper(
obs_normalize=cfgs.algo_cfgs.obs_normalize,
Expand Down Expand Up @@ -109,8 +115,20 @@ def _wrapper(
cost_normalize (bool, optional): Whether to normalize the cost. Defaults to True.
"""
if self._env.need_time_limit_wrapper:
self._env = TimeLimit(self._env, time_limit=1000, device=self._device)
self._eval_env = TimeLimit(self._eval_env, time_limit=1000, device=self._device)
assert (
self._env.max_episode_steps and self._eval_env.max_episode_steps
), 'You must define max_episode_steps as an integer\
or cancel the use of the time_limit wrapper.'
self._env = TimeLimit(
self._env,
time_limit=self._env.max_episode_steps,
device=self._device,
)
self._eval_env = TimeLimit(
self._eval_env,
time_limit=self._eval_env.max_episode_steps,
device=self._device,
)
if self._env.need_auto_reset_wrapper:
self._env = AutoReset(self._env, device=self._device)
self._eval_env = AutoReset(self._eval_env, device=self._device)
Expand Down Expand Up @@ -192,3 +210,14 @@ def save(self) -> dict[str, torch.nn.Module]:
The saved components of environment, e.g., ``obs_normalizer``.
"""
return self._env.save()

def close(self) -> None:
"""Close the environment after training."""
self._env.close()

@property
def env_spec_keys(self) -> list[str]:
"""Return the environment specification log."""
if hasattr(self._env, 'env_spec_log'):
return list(self._env.env_spec_log.keys())
return []
2 changes: 2 additions & 0 deletions omnisafe/adapter/onpolicy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def _log_metrics(self, logger: Logger, idx: int) -> None:
logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
idx (int): The index of the environment.
"""
if hasattr(self._env, 'spec_log'):
self._env.spec_log(logger)
logger.store(
{
'Metrics/EpRet': self._ep_ret[idx],
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/algorithms/model_based/base/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _store_real_data( # pylint: disable=too-many-arguments,unused-argument
info (dict[str, Any]): The information from the environment.
"""
done = terminated or truncated
goal_met = False if 'goal_met' not in info else info['goal_met']
goal_met = info.get('goal_met', False)
if not done and not goal_met:
# when goal_met == true:
# current goal position is not related to the last goal position,
Expand Down
4 changes: 2 additions & 2 deletions omnisafe/algorithms/model_based/base/pets.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _update_dynamics_model(
)

def _update_epoch(self) -> None:
...
"""Update function per epoch."""

def _select_action( # pylint: disable=unused-argument
self,
Expand Down Expand Up @@ -384,7 +384,7 @@ def _store_real_data( # pylint: disable=too-many-arguments,unused-argument
info (dict[str, Any]): The information from the environment.
"""
done = terminated or truncated
goal_met = False if 'goal_met' not in info else info['goal_met']
goal_met = info.get('goal_met', False)
if not terminated and not truncated and not goal_met:
# pylint: disable-next=line-too-long
# if goal_met == true, Current goal position is not related to the last goal position, this huge transition will confuse the dynamics model.
Expand Down
4 changes: 4 additions & 0 deletions omnisafe/algorithms/off_policy/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def _init_log(self) -> None:
self._logger.register_key('Time/Evaluate')
self._logger.register_key('Time/Epoch')
self._logger.register_key('Time/FPS')
# register environment specific keys
for env_spec_key in self._env.env_spec_keys:
self.logger.register_key(env_spec_key)

def learn(self) -> tuple[float, float, float]:
"""This is main function for algorithm update.
Expand Down Expand Up @@ -321,6 +324,7 @@ def learn(self) -> tuple[float, float, float]:
ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
ep_len = self._logger.get_stats('Metrics/EpLen')[0]
self._logger.close()
self._env.close()

return ep_ret, ep_cost, ep_len

Expand Down
Loading
Loading