From efe7d9cea96e743511e11b2056cbf05b5932364e Mon Sep 17 00:00:00 2001 From: Jiayi Zhou <108712610+Gaiejj@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:24:37 +0800 Subject: [PATCH] feat: support interface of environment customization (#310) --- .pre-commit-config.yaml | 10 +- docs/requirements.txt | 1 + docs/source/envs/custom.rst | 22 + docs/source/index.rst | 2 + docs/source/spelling_wordlist.txt | 1 + docs/source/start/env.rst | 73 + examples/train_from_custom_env.py | 90 ++ omnisafe/adapter/offpolicy_adapter.py | 6 +- omnisafe/adapter/online_adapter.py | 37 +- omnisafe/adapter/onpolicy_adapter.py | 2 + omnisafe/algorithms/model_based/base/loop.py | 2 +- omnisafe/algorithms/model_based/base/pets.py | 4 +- omnisafe/algorithms/off_policy/ddpg.py | 4 + .../on_policy/base/policy_gradient.py | 17 +- omnisafe/configs/on-policy/PPOLag.yaml | 2 + omnisafe/envs/__init__.py | 1 + omnisafe/envs/core.py | 36 +- omnisafe/envs/custom_env.py | 199 +++ omnisafe/envs/mujoco_env.py | 17 +- omnisafe/envs/safety_gymnasium_env.py | 18 +- omnisafe/envs/safety_gymnasium_modelbased.py | 17 +- omnisafe/envs/wrapper.py | 1 + omnisafe/evaluator.py | 19 +- omnisafe/models/actor_critic/actor_critic.py | 2 - .../models/actor_critic/actor_q_critic.py | 2 - omnisafe/utils/config.py | 4 +- omnisafe/utils/tools.py | 4 +- tests/distribution_train.py | 2 +- .../{Simple-v0.npz => Test-v0.npz} | Bin tests/simple_env.py | 14 +- tests/test_env.py | 7 +- tests/test_policy.py | 34 +- tests/test_registry.py | 19 + ....Environment Customization from Zero.ipynb | 1439 ++++++++++++++++ ...ronment Customization from Community.ipynb | 916 +++++++++++ ....Environment Customization from Zero.ipynb | 1440 +++++++++++++++++ ...ronment Customization from Community.ipynb | 903 +++++++++++ 37 files changed, 5253 insertions(+), 114 deletions(-) create mode 100644 docs/source/envs/custom.rst create mode 100644 docs/source/start/env.rst create mode 100644 examples/train_from_custom_env.py create mode 100644 omnisafe/envs/custom_env.py rename tests/saved_source/{Simple-v0.npz => Test-v0.npz} (100%) create mode 100644 tutorials/English/3.Environment Customization from Zero.ipynb create mode 100644 tutorials/English/4.Environment Customization from Community.ipynb create mode 100644 tutorials/zh-cn/3.Environment Customization from Zero.ipynb create mode 100644 tutorials/zh-cn/4.Environment Customization from Community.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8376e71f..0a80ef22d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,25 +29,25 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.292 + rev: v0.3.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 24.3.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + rev: v3.15.2 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: diff --git a/docs/requirements.txt b/docs/requirements.txt index 55f237969..4ee781c54 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -8,3 +8,4 @@ sphinx-autoapi sphinx-autobuild sphinx-autodoc-typehints furo +sphinxcontrib-spelling diff --git a/docs/source/envs/custom.rst b/docs/source/envs/custom.rst 
new file mode 100644 index 000000000..85e33a1a7 --- /dev/null +++ b/docs/source/envs/custom.rst @@ -0,0 +1,22 @@ +OmniSafe Customization Interface of Environments +================================================ + +.. currentmodule:: omnisafe.envs.custom_env + +.. autosummary:: + + CustomEnv + +CustomEnv +--------- + +.. card:: + :class-header: sd-bg-success sd-text-white + :class-card: sd-outline-success sd-rounded-1 + + Documentation + ^^^ + + .. autoclass:: CustomEnv + :members: + :private-members: diff --git a/docs/source/index.rst b/docs/source/index.rst index e759bebee..792f62052 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -366,6 +366,7 @@ this project, don't hesitate to ask your question on `the GitHub issue page `_ style, implement +environment customization and complete the training process. + + +Customization of Your Environments +----------------------------------- + +From Source Code +^^^^^^^^^^^^^^^^ + +If you are installing from the source code, you can follow the steps below: + +.. card:: + :class-header: sd-bg-success sd-text-white + :class-card: sd-outline-success sd-rounded-1 + + Build from Source Code + ^^^ + 1. Create a new file under `omnisafe/envs/`, for example, `omnisafe/envs/my_env.py`. + 2. Customize the environment in `omnisafe/envs/my_env.py`. Assuming the class name is `MyEnv`, and the environment name is `MyEnv-v0`. + 3. Add `from .my_env import MyEnv` in `omnisafe/envs/__init__.py`. + 4. Run the following command in the `omnisafe/examples` folder: + + .. code-block:: bash + :linenos: + + python train_policy.py --algo PPOLag --env MyEnv-v0 + +From PyPI +^^^^^^^^^ + +.. card:: + :class-header: sd-bg-success sd-text-white + :class-card: sd-outline-success sd-rounded-1 + + Build from PyPI + ^^^ + 1. Customize the environment in any folder. Assuming the class name is `MyEnv`, and the environment name is `MyEnv-v0`. + 2. Import OmniSafe and the environment registration decorator. + 3. Run the training. + + For a short but detailed example, please see `examples/train_from_custom_env.py` diff --git a/examples/train_from_custom_env.py b/examples/train_from_custom_env.py new file mode 100644 index 000000000..f882b877d --- /dev/null +++ b/examples/train_from_custom_env.py @@ -0,0 +1,90 @@ +# Copyright 2024 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Example and template for environment customization.""" + +from __future__ import annotations + +import random +from typing import Any, ClassVar + +import torch +from gymnasium import spaces + +import omnisafe +from omnisafe.envs.core import CMDP, env_register + + +# first, define the environment class. +# the most important thing is to add the `env_register` decorator. +@env_register +class CustomExampleEnv(CMDP): + + # define what tasks the environment support. 
+ _support_envs: ClassVar[list[str]] = ['Custom-v0'] + + # automatically reset when `terminated` or `truncated` + need_auto_reset_wrapper = True + # set `truncated=True` when the total steps exceed the time limit. + need_time_limit_wrapper = True + + def __init__(self, env_id: str, **kwargs: dict[str, Any]) -> None: + self._count = 0 + self._num_envs = 1 + self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) + self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,)) + + def step( + self, + action: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]: + self._count += 1 + obs = torch.as_tensor(self._observation_space.sample()) + reward = 2 * torch.as_tensor(random.random()) # noqa + cost = 2 * torch.as_tensor(random.random()) # noqa + terminated = torch.as_tensor(random.random() > 0.9) # noqa + truncated = torch.as_tensor(self._count > self.max_episode_steps) + return obs, reward, cost, terminated, truncated, {'final_observation': obs} + + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return 10 + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict]: + self.set_seed(seed) + obs = torch.as_tensor(self._observation_space.sample()) + self._count = 0 + return obs, {} + + def set_seed(self, seed: int) -> None: + random.seed(seed) + + def close(self) -> None: + pass + + def render(self) -> Any: + pass + + +# Then you can use it like this: +agent = omnisafe.Agent( + 'PPOLag', + 'Custom-v0', +) +agent.learn() diff --git a/omnisafe/adapter/offpolicy_adapter.py b/omnisafe/adapter/offpolicy_adapter.py index 39ce1df1f..7137b0e18 100644 --- a/omnisafe/adapter/offpolicy_adapter.py +++ b/omnisafe/adapter/offpolicy_adapter.py @@ -126,9 +126,7 @@ def rollout( # pylint: disable=too-many-locals """ for _ in range(rollout_step): if use_rand_action: - act = torch.as_tensor(self._env.sample_action(), dtype=torch.float32).to( - self._device, - ) + act = (torch.rand(self.action_space.shape) * 2 - 1).unsqueeze(0) # type: ignore else: act = agent.step(self._current_obs, deterministic=False) next_obs, reward, cost, terminated, truncated, info = self.step(act) @@ -181,6 +179,8 @@ def _log_metrics(self, logger: Logger, idx: int) -> None: logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. idx (int): The index of the environment. """ + if hasattr(self._env, 'spec_log'): + self._env.spec_log(logger) logger.store( { 'Metrics/EpRet': self._ep_ret[idx], diff --git a/omnisafe/adapter/online_adapter.py b/omnisafe/adapter/online_adapter.py index 816de8456..7da3c53df 100644 --- a/omnisafe/adapter/online_adapter.py +++ b/omnisafe/adapter/online_adapter.py @@ -62,8 +62,14 @@ def __init__( # pylint: disable=too-many-arguments self._cfgs: Config = cfgs self._device: torch.device = get_device(cfgs.train_cfgs.device) self._env_id: str = env_id - self._env: CMDP = make(env_id, num_envs=num_envs, device=self._device) - self._eval_env: CMDP = make(env_id, num_envs=1, device=self._device) + + env_cfgs = {} + + if hasattr(self._cfgs, 'env_cfgs') and self._cfgs.env_cfgs is not None: + env_cfgs = self._cfgs.env_cfgs.todict() + + self._env: CMDP = make(env_id, num_envs=num_envs, device=self._device, **env_cfgs) + self._eval_env: CMDP = make(env_id, num_envs=1, device=self._device, **env_cfgs) self._wrapper( obs_normalize=cfgs.algo_cfgs.obs_normalize, @@ -109,8 +115,20 @@ def _wrapper( cost_normalize (bool, optional): Whether to normalize the cost. 
Defaults to True. """ if self._env.need_time_limit_wrapper: - self._env = TimeLimit(self._env, time_limit=1000, device=self._device) - self._eval_env = TimeLimit(self._eval_env, time_limit=1000, device=self._device) + assert ( + self._env.max_episode_steps and self._eval_env.max_episode_steps + ), 'You must define max_episode_steps as an integer\ + or cancel the use of the time_limit wrapper.' + self._env = TimeLimit( + self._env, + time_limit=self._env.max_episode_steps, + device=self._device, + ) + self._eval_env = TimeLimit( + self._eval_env, + time_limit=self._eval_env.max_episode_steps, + device=self._device, + ) if self._env.need_auto_reset_wrapper: self._env = AutoReset(self._env, device=self._device) self._eval_env = AutoReset(self._eval_env, device=self._device) @@ -192,3 +210,14 @@ def save(self) -> dict[str, torch.nn.Module]: The saved components of environment, e.g., ``obs_normalizer``. """ return self._env.save() + + def close(self) -> None: + """Close the environment after training.""" + self._env.close() + + @property + def env_spec_keys(self) -> list[str]: + """Return the environment specification log.""" + if hasattr(self._env, 'env_spec_log'): + return list(self._env.env_spec_log.keys()) + return [] diff --git a/omnisafe/adapter/onpolicy_adapter.py b/omnisafe/adapter/onpolicy_adapter.py index 087ac32b7..1012f138a 100644 --- a/omnisafe/adapter/onpolicy_adapter.py +++ b/omnisafe/adapter/onpolicy_adapter.py @@ -158,6 +158,8 @@ def _log_metrics(self, logger: Logger, idx: int) -> None: logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. idx (int): The index of the environment. """ + if hasattr(self._env, 'spec_log'): + self._env.spec_log(logger) logger.store( { 'Metrics/EpRet': self._ep_ret[idx], diff --git a/omnisafe/algorithms/model_based/base/loop.py b/omnisafe/algorithms/model_based/base/loop.py index b269a9f65..91bf78ab5 100644 --- a/omnisafe/algorithms/model_based/base/loop.py +++ b/omnisafe/algorithms/model_based/base/loop.py @@ -305,7 +305,7 @@ def _store_real_data( # pylint: disable=too-many-arguments,unused-argument info (dict[str, Any]): The information from the environment. """ done = terminated or truncated - goal_met = False if 'goal_met' not in info else info['goal_met'] + goal_met = info.get('goal_met', False) if not done and not goal_met: # when goal_met == true: # current goal position is not related to the last goal position, diff --git a/omnisafe/algorithms/model_based/base/pets.py b/omnisafe/algorithms/model_based/base/pets.py index e4a847341..f40869ca0 100644 --- a/omnisafe/algorithms/model_based/base/pets.py +++ b/omnisafe/algorithms/model_based/base/pets.py @@ -333,7 +333,7 @@ def _update_dynamics_model( ) def _update_epoch(self) -> None: - ... + """Update function per epoch.""" def _select_action( # pylint: disable=unused-argument self, @@ -384,7 +384,7 @@ def _store_real_data( # pylint: disable=too-many-arguments,unused-argument info (dict[str, Any]): The information from the environment. """ done = terminated or truncated - goal_met = False if 'goal_met' not in info else info['goal_met'] + goal_met = info.get('goal_met', False) if not terminated and not truncated and not goal_met: # pylint: disable-next=line-too-long # if goal_met == true, Current goal position is not related to the last goal position, this huge transition will confuse the dynamics model. 
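For orientation, here is a minimal, hedged sketch of how the new `env_cfgs` entry is intended to be consumed: the online adapter above converts `cfgs.env_cfgs` to a dict and forwards it to `make()` as keyword arguments for the environment's `__init__`. The sketch assumes the `Custom-v0` environment from `examples/train_from_custom_env.py` has already been registered in the current process; `noise_scale` is a purely hypothetical option, not something defined in this patch.

```python
# Hedged usage sketch (not part of this patch).  Assumes the `Custom-v0`
# environment from examples/train_from_custom_env.py is already registered.
# 'noise_scale' is a hypothetical option: the adapter forwards it to the
# environment's __init__ via make(env_id, ..., **env_cfgs), where it is simply
# absorbed by CustomExampleEnv's **kwargs.
import omnisafe

custom_cfgs = {'env_cfgs': {'noise_scale': 0.1}}
agent = omnisafe.Agent('PPOLag', 'Custom-v0', custom_cfgs=custom_cfgs)
agent.learn()
```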
diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py index e833689da..e7199c32c 100644 --- a/omnisafe/algorithms/off_policy/ddpg.py +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -228,6 +228,9 @@ def _init_log(self) -> None: self._logger.register_key('Time/Evaluate') self._logger.register_key('Time/Epoch') self._logger.register_key('Time/FPS') + # register environment specific keys + for env_spec_key in self._env.env_spec_keys: + self.logger.register_key(env_spec_key) def learn(self) -> tuple[float, float, float]: """This is main function for algorithm update. @@ -321,6 +324,7 @@ def learn(self) -> tuple[float, float, float]: ep_cost = self._logger.get_stats('Metrics/EpCost')[0] ep_len = self._logger.get_stats('Metrics/EpLen')[0] self._logger.close() + self._env.close() return ep_ret, ep_cost, ep_len diff --git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py index 01409b574..4c1539178 100644 --- a/omnisafe/algorithms/on_policy/base/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py @@ -222,6 +222,10 @@ def _init_log(self) -> None: self._logger.register_key('Time/Epoch') self._logger.register_key('Time/FPS') + # register environment specific keys + for env_spec_key in self._env.env_spec_keys: + self.logger.register_key(env_spec_key) + def learn(self) -> tuple[float, float, float]: """This is main function for algorithm update. @@ -268,22 +272,27 @@ def learn(self) -> tuple[float, float, float]: 'Time/Total': (time.time() - start_time), 'Time/Epoch': (time.time() - epoch_time), 'Train/Epoch': epoch, - 'Train/LR': 0.0 - if self._cfgs.model_cfgs.actor.lr is None - else self._actor_critic.actor_scheduler.get_last_lr()[0], + 'Train/LR': ( + 0.0 + if self._cfgs.model_cfgs.actor.lr is None + else self._actor_critic.actor_scheduler.get_last_lr()[0] + ), }, ) self._logger.dump_tabular() # save model to disk - if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0: + if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0 or ( + epoch + 1 + ) == self._cfgs.train_cfgs.epochs: self._logger.torch_save() ep_ret = self._logger.get_stats('Metrics/EpRet')[0] ep_cost = self._logger.get_stats('Metrics/EpCost')[0] ep_len = self._logger.get_stats('Metrics/EpLen')[0] self._logger.close() + self._env.close() return ep_ret, ep_cost, ep_len diff --git a/omnisafe/configs/on-policy/PPOLag.yaml b/omnisafe/configs/on-policy/PPOLag.yaml index 4d673fdbe..b01b07392 100644 --- a/omnisafe/configs/on-policy/PPOLag.yaml +++ b/omnisafe/configs/on-policy/PPOLag.yaml @@ -128,3 +128,5 @@ defaults: lambda_lr: 0.035 # Type of lagrangian optimizer lambda_optimizer: "Adam" + # environment specific configurations + env_cfgs: {} diff --git a/omnisafe/envs/__init__.py b/omnisafe/envs/__init__.py index 7a9f2ea2b..57778938e 100644 --- a/omnisafe/envs/__init__.py +++ b/omnisafe/envs/__init__.py @@ -15,6 +15,7 @@ """Environment API for OmniSafe.""" from omnisafe.envs.core import CMDP, env_register, make, support_envs +from omnisafe.envs.custom_env import CustomEnv from omnisafe.envs.mujoco_env import MujocoEnv from omnisafe.envs.safety_gymnasium_env import SafetyGymnasiumEnv from omnisafe.envs.safety_gymnasium_modelbased import SafetyGymnasiumModelBased diff --git a/omnisafe/envs/core.py b/omnisafe/envs/core.py index 999ac45fe..890a6be84 100644 --- a/omnisafe/envs/core.py +++ b/omnisafe/envs/core.py @@ -82,6 +82,11 @@ def observation_space(self) -> OmnisafeSpace: """The observation space of the 
environment.""" return self._observation_space + @property + def max_episode_steps(self) -> int | None: + """The max steps per episode.""" + return None + @property def metadata(self) -> dict[str, Any]: """The metadata of the environment.""" @@ -148,14 +153,6 @@ def set_seed(self, seed: int) -> None: seed (int): The seed to use. """ - @abstractmethod - def sample_action(self) -> torch.Tensor: - """Sample an action from the action space. - - Returns: - The sampled action. - """ - @abstractmethod def render(self) -> Any: """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. @@ -268,14 +265,6 @@ def set_seed(self, seed: int) -> None: """ self._env.set_seed(seed) - def sample_action(self) -> torch.Tensor: - """Sample an action from the action space. - - Returns: - The sampled action. - """ - return self._env.sample_action() - def render(self) -> Any: """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. @@ -353,6 +342,20 @@ def register(self, env_class: type[CMDP]) -> type[CMDP]: self._register(env_class) return env_class + def unregister(self, env_class: type[CMDP]) -> type[CMDP]: + """Remove the environment from the register. + + Args: + env_class (type[CMDP]): The environment class. + """ + class_name = env_class.__name__ + if class_name not in self._class: + print(f'{class_name} has not been registered yet') + else: + self._class.pop(class_name) + self._support_envs.pop(class_name) + return env_class + def get_class(self, env_id: str, class_name: str | None) -> type[CMDP]: """Get the environment class. @@ -388,6 +391,7 @@ def support_envs(self) -> list[str]: env_register = ENV_REGISTRY.register support_envs = ENV_REGISTRY.support_envs +env_unregister = ENV_REGISTRY.unregister def make(env_id: str, class_name: str | None = None, **kwargs: Any) -> CMDP: diff --git a/omnisafe/envs/custom_env.py b/omnisafe/envs/custom_env.py new file mode 100644 index 000000000..ac64545da --- /dev/null +++ b/omnisafe/envs/custom_env.py @@ -0,0 +1,199 @@ +# Copyright 2024 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Example and template for environment customization.""" + +from __future__ import annotations + +import random +from typing import Any, ClassVar + +import numpy as np +import torch +from gymnasium import spaces + +from omnisafe.common.logger import Logger +from omnisafe.envs.core import CMDP, env_register +from omnisafe.typing import OmnisafeSpace + + +@env_register +class CustomEnv(CMDP): + """Simplest environment for the example and template for environment customization. + + If you wish for your environment to become part of the officially supported environments by + OmniSafe, please refer to this document to implement environment embedding. We will welcome + your GitHub pull request. 
+ + Customizing the environment in OmniSafe requires specifying the following parameters: + + Attributes: + _support_envs (ClassVar[list[str]]): A list composed of strings, used to display all task + names supported by the customized environment. For example: ['Simple-v0']. + _action_space: The action space of the task. It can be defined by directly passing an + :class:`OmniSafeSpace` object, or specified in :meth:`__init__` based on the + characteristics of the customized environment. + _observation_space: The observation space of the task. It can be defined by directly + passing an :class:`OmniSafeSpace` object, or specified in :meth:`__init__` based on the + characteristics of the customized environment. + metadata (ClassVar[dict[str, int]]): A class variable containing environment metadata, such + as render FPS. + need_time_limit_wrapper (bool): Whether the environment needs a time limit wrapper. + need_auto_reset_wrapper (bool): Whether the environment needs an auto-reset wrapper. + _num_envs (int): The number of parallel environments. + + .. warning:: + The :class:`omnisafe.adapter.OnlineAdapter`, :class:`omnisafe.adapter.OfflineAdapter`, and + :class:`omnisafe.adapter.ModelBasedAdapter` implemented by OmniSafe use + :class:`omnisafe.envs.wrapper.AutoReset` and :class:`omnisafe.envs.wrapper.TimeLimit` in + algorithm updates. We recommend setting :attr:`need_auto_reset_wrapper` and + :attr:`need_time_limit_wrapper` to ``True``. If you do not want to use these wrappers, you + can add customized logic in the :meth:`step` function of the customized + environment. + """ + + _support_envs: ClassVar[list[str]] = ['Simple-v0'] + _action_space: OmnisafeSpace + _observation_space: OmnisafeSpace + metadata: ClassVar[dict[str, int]] = {} + env_spec_log: dict[str, Any] + + need_auto_reset_wrapper = True + need_time_limit_wrapper = True + _num_envs = 1 + + def __init__( + self, + env_id: str, + **kwargs: Any, # pylint: disable=unused-argument + ) -> None: + """Initialize CustomEnv with the given environment ID and optional keyword arguments. + + .. note:: + Optionally, you can specify some environment-specific information that needs to be + logged. You need to complete this operation in two steps: + + 1. Define the environment information in dictionary format in :meth:`__init__`. + 2. Log the environment information in :meth:`spec_log`. Please note that the logging in + OmniSafe will occur at the end of each episode, so you need to consider how to + reset the logging values for each episode. + + Example: + >>> # First, define the environment information in dictionary format in __init__. + >>> def __init__(self, env_id: str, **kwargs: Any) -> None: + >>> self.env_spec_log = {'Env/Interaction': 0,} + >>> + >>> # Then, log and reset the environment information in spec_log. + >>> def spec_log(self, logger: Logger) -> dict[str, Any]: + >>> logger.store({'Env/Interaction': self.env_spec_log['Env/Interaction']}) + >>> self.env_spec_log['Env/Interaction'] = 0 + + Args: + env_id (str): The environment ID. + **kwargs: Additional keyword arguments. + """ + self._count = 0 + self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) + self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,)) + self._max_episode_steps = 10 + self.env_spec_log = {} + + def step( + self, + action: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]: + """Run one timestep of the environment's dynamics using the agent actions. + + .. 
note:: + You need to implement dynamic features related to environment interaction here. That is: + + 1. Update the environment state based on the action; + 2. Calculate reward and cost based on the environment state; + 3. Determine whether to terminate based on the environment state; + 4. Record the information you need. + + Args: + action (torch.Tensor): The action from the agent or random. + + Returns: + observation: The agent's observation of the current environment. + reward: The amount of reward returned after previous action. + cost: The amount of cost returned after previous action. + terminated: Whether the episode has ended. + truncated: Whether the episode has been truncated due to a time limit. + info: Some information logged by the environment. + """ + self._count += 1 + obs = torch.as_tensor(self._observation_space.sample()) + reward = 10000 * torch.as_tensor(random.random()) # noqa + cost = 10000 * torch.as_tensor(random.random()) # noqa + terminated = torch.as_tensor(random.random() > 0.9) # noqa + truncated = torch.as_tensor(self._count > self._max_episode_steps) + return obs, reward, cost, terminated, truncated, {'final_observation': obs} + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict]: + """Reset the environment. + + Args: + seed (int, optional): The random seed to use for the environment. Defaults to None. + options (dict[str, Any], optional): Additional options. Defaults to None. + + Returns: + tuple[torch.Tensor, dict]: A tuple containing: + - obs (torch.Tensor): The initial observation. + - info (dict): Additional information. + """ + if seed is not None: + self.set_seed(seed) + obs = torch.as_tensor(self._observation_space.sample()) + self._count = 0 + return obs, {} + + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return 10 + + def spec_log(self, logger: Logger) -> None: + """Log specific environment into logger. + + .. note:: + This function will be called after each episode. + + Args: + logger (Logger): The logger to use for logging. + """ + + def set_seed(self, seed: int) -> None: + """Set the random seed for the environment. + + Args: + seed (int): The random seed. + """ + random.seed(seed) + + def render(self) -> Any: + """Render the environment. + + Returns: + Any: An array representing the rendered environment. + """ + return np.zeros((100, 100, 3), dtype=np.uint8) + + def close(self) -> None: + """Close the environment.""" diff --git a/omnisafe/envs/mujoco_env.py b/omnisafe/envs/mujoco_env.py index f5ba67f53..2bee2f13d 100644 --- a/omnisafe/envs/mujoco_env.py +++ b/omnisafe/envs/mujoco_env.py @@ -142,6 +142,11 @@ def step( return obs, reward, cost, terminated, truncated, info + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return self._env.env.spec.max_episode_steps # type: ignore + def reset( self, seed: int | None = None, @@ -168,18 +173,6 @@ def set_seed(self, seed: int) -> None: """ self.reset(seed=seed) - def sample_action(self) -> torch.Tensor: - """Sample a random action. - - Returns: - A random action. - """ - return torch.as_tensor( - self._env.action_space.sample(), - dtype=torch.float32, - device=self._device, - ) - def render(self) -> Any: """Render the environment. 
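To make the `spec_log` hook above concrete, the following hedged sketch follows the `Env/Interaction` example from the `CustomEnv` docstring: any key placed in `env_spec_log` is registered by the algorithm through the adapter's `env_spec_keys` property, and `spec_log` is called by the on-/off-policy adapters when episode metrics are logged. `Env/Interaction` is only an illustrative key, and the subclass below is shown for illustration; it is not registered.

```python
# Hedged sketch of the environment-specific logging hook, following the docstring
# example above.  'Env/Interaction' is an illustrative key: anything placed in
# `env_spec_log` is registered by the algorithm (via `env_spec_keys`) and should be
# stored and reset here, since the adapters call `spec_log` at logging time.
from __future__ import annotations

from typing import Any

import torch

from omnisafe.common.logger import Logger
from omnisafe.envs.custom_env import CustomEnv


class LoggingCustomEnv(CustomEnv):
    """CustomEnv variant that counts interactions per episode (illustration only)."""

    def __init__(self, env_id: str, **kwargs: Any) -> None:
        super().__init__(env_id, **kwargs)
        self.env_spec_log = {'Env/Interaction': 0}

    def step(self, action: torch.Tensor):
        # Count every interaction; the counter is flushed and reset in spec_log.
        self.env_spec_log['Env/Interaction'] += 1
        return super().step(action)

    def spec_log(self, logger: Logger) -> None:
        logger.store({'Env/Interaction': self.env_spec_log['Env/Interaction']})
        self.env_spec_log['Env/Interaction'] = 0
```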
diff --git a/omnisafe/envs/safety_gymnasium_env.py b/omnisafe/envs/safety_gymnasium_env.py index f9eb75d20..80460add7 100644 --- a/omnisafe/envs/safety_gymnasium_env.py +++ b/omnisafe/envs/safety_gymnasium_env.py @@ -220,7 +220,6 @@ def reset( seed (int, optional): The random seed. Defaults to None. options (dict[str, Any], optional): The options for the environment. Defaults to None. - Returns: observation: Agent's observation of the current environment. info: Some information logged by the environment. @@ -228,6 +227,11 @@ def reset( obs, info = self._env.reset(seed=seed, options=options) return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return self._env.env.spec.max_episode_steps + def set_seed(self, seed: int) -> None: """Set the seed for the environment. @@ -236,18 +240,6 @@ def set_seed(self, seed: int) -> None: """ self.reset(seed=seed) - def sample_action(self) -> torch.Tensor: - """Sample a random action. - - Returns: - A random action. - """ - return torch.as_tensor( - self._env.action_space.sample(), - dtype=torch.float32, - device=self._device, - ) - def render(self) -> Any: """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. diff --git a/omnisafe/envs/safety_gymnasium_modelbased.py b/omnisafe/envs/safety_gymnasium_modelbased.py index 3edc7584b..fe5ae5071 100644 --- a/omnisafe/envs/safety_gymnasium_modelbased.py +++ b/omnisafe/envs/safety_gymnasium_modelbased.py @@ -465,6 +465,11 @@ def step( return obs, reward, cost, terminated, truncated, info + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return self._env.env.spec.max_episode_steps + def reset( self, seed: int | None = None, @@ -504,18 +509,6 @@ def set_seed(self, seed: int) -> None: """ self.reset(seed=seed) - def sample_action(self) -> torch.Tensor: - """Sample a random action. - - Returns: - The sampled action. - """ - return torch.as_tensor( - self._env.action_space.sample(), - dtype=torch.float32, - device=self._device, - ) - def render(self) -> Any: """Render the environment. 
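With `sample_action` removed and each built-in environment now exposing `max_episode_steps`, the `TimeLimit` wrapper (see the online adapter change earlier in this patch) derives its limit from the environment instead of a hard-coded 1000. A hedged sketch of that wiring, assuming the `safety-gymnasium` package is installed so `SafetyPointGoal1-v0` is available:

```python
# Hedged sketch (assumes safety-gymnasium is installed): the time limit now comes
# from the environment's own `max_episode_steps`, read from the underlying
# gymnasium spec, rather than a hard-coded value.
import torch

from omnisafe.envs.core import make
from omnisafe.envs.wrapper import TimeLimit

device = torch.device('cpu')
env = make('SafetyPointGoal1-v0', num_envs=1, device=device)
assert env.max_episode_steps is not None, 'define max_episode_steps or skip TimeLimit'
env = TimeLimit(env, time_limit=env.max_episode_steps, device=device)
```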
diff --git a/omnisafe/envs/wrapper.py b/omnisafe/envs/wrapper.py index 5fe8525e7..db2e13bdc 100644 --- a/omnisafe/envs/wrapper.py +++ b/omnisafe/envs/wrapper.py @@ -510,6 +510,7 @@ def step( action = self._old_min_action + (self._old_max_action - self._old_min_action) * ( action - self._min_action ) / (self._max_action - self._min_action) + return super().step(action) diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 7664fd29e..2fda35197 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -54,6 +54,7 @@ class Evaluator: # pylint: disable=too-many-instance-attributes """ _cfgs: Config + _dict_cfgs: dict[str, Any] _save_dir: str _model_name: str _cost_count: torch.Tensor @@ -65,13 +66,9 @@ def __init__( actor: Actor | None = None, actor_critic: ConstraintActorCritic | ConstraintActorQCritic | None = None, dynamics: EnsembleDynamicsModel | None = None, - planner: CEMPlanner - | ARCPlanner - | SafeARCPlanner - | CCEPlanner - | CAPPlanner - | RCEPlanner - | None = None, + planner: ( + CEMPlanner | ARCPlanner | SafeARCPlanner | CCEPlanner | CAPPlanner | RCEPlanner | None + ) = None, render_mode: str = 'rgb_array', ) -> None: """Initialize an instance of :class:`Evaluator`.""" @@ -119,6 +116,7 @@ def __load_cfgs(self, save_dir: str) -> None: raise FileNotFoundError( f'The config file is not found in the save directory{save_dir}.', ) from error + self._dict_cfgs = kwargs self._cfgs = Config.dict2config(kwargs) # pylint: disable-next=too-many-branches @@ -331,6 +329,8 @@ def load_saved( 'width': width, 'height': height, } + if self._dict_cfgs.get('env_cfgs') is not None: + env_kwargs.update(self._dict_cfgs['env_cfgs']) self.__load_model_and_env(save_dir, model_name, env_kwargs) @@ -415,6 +415,8 @@ def evaluate( print(f'Average episode reward: {np.mean(a=episode_rewards)}') print(f'Average episode cost: {np.mean(a=episode_costs)}') print(f'Average episode length: {np.mean(a=episode_lengths)}') + + self._env.close() return ( episode_rewards, episode_costs, @@ -433,7 +435,7 @@ def fps(self) -> int: ), 'The environment must be provided or created before getting the fps.' 
try: fps = self._env.metadata['render_fps'] - except AttributeError: + except (AttributeError, KeyError): fps = 30 warnings.warn('The fps is not found, use 30 as default.', stacklevel=2) @@ -552,3 +554,4 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc print(f'Average episode reward: {np.mean(episode_rewards)}', file=f) print(f'Average episode cost: {np.mean(episode_costs)}', file=f) print(f'Average episode length: {np.mean(episode_lengths)}', file=f) + self._env.close() diff --git a/omnisafe/models/actor_critic/actor_critic.py b/omnisafe/models/actor_critic/actor_critic.py index 81a9b7a49..697c2cfcc 100644 --- a/omnisafe/models/actor_critic/actor_critic.py +++ b/omnisafe/models/actor_critic/actor_critic.py @@ -104,14 +104,12 @@ def __init__( start_factor=1.0, end_factor=0.0, total_iters=epochs, - verbose=True, ) else: self.actor_scheduler = ConstantLR( self.actor_optimizer, factor=1.0, total_iters=epochs, - verbose=True, ) def step( diff --git a/omnisafe/models/actor_critic/actor_q_critic.py b/omnisafe/models/actor_critic/actor_q_critic.py index 6cad7df43..327424763 100644 --- a/omnisafe/models/actor_critic/actor_q_critic.py +++ b/omnisafe/models/actor_critic/actor_q_critic.py @@ -115,14 +115,12 @@ def __init__( start_factor=1.0, end_factor=0.0, total_iters=epochs, - verbose=True, ) else: self.actor_scheduler = ConstantLR( self.actor_optimizer, factor=1.0, total_iters=epochs, - verbose=True, ) def step(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor: diff --git a/omnisafe/utils/config.py b/omnisafe/utils/config.py index 29698b5f6..3c46872df 100644 --- a/omnisafe/utils/config.py +++ b/omnisafe/utils/config.py @@ -245,7 +245,7 @@ def get_default_kwargs_yaml(algo: str, env_id: str, algo_type: str) -> Config: print(f'Loading {algo}.yaml from {cfg_path}') kwargs = load_yaml(cfg_path) default_kwargs = kwargs['defaults'] - env_spec_kwargs = kwargs[env_id] if env_id in kwargs else None + env_spec_kwargs = kwargs.get(env_id) default_kwargs = Config.dict2config(default_kwargs) @@ -347,7 +347,7 @@ def __check_algo_configs(configs: Config, algo_type: str) -> None: assert isinstance(configs.max_grad_norm, float) and isinstance( configs.critic_norm_coef, float, - ), 'norm must be bool' + ), 'norm must be float' assert ( isinstance(configs.gamma, float) and configs.gamma >= 0.0 and configs.gamma <= 1.0 ), 'gamma must be float, and it values must be [0.0, 1.0]' diff --git a/omnisafe/utils/tools.py b/omnisafe/utils/tools.py index 4e473a253..77710c3ee 100644 --- a/omnisafe/utils/tools.py +++ b/omnisafe/utils/tools.py @@ -273,7 +273,9 @@ def recursive_check_config( for key in config: if key not in default_config and key not in exclude_keys: raise KeyError(f'Invalid key: {key}') - if isinstance(config[key], dict): + if config[key] is None: + return + if isinstance(config[key], dict) and key != 'env_cfgs': recursive_check_config(config[key], default_config[key]) diff --git a/tests/distribution_train.py b/tests/distribution_train.py index e3c15683b..a5823f9e8 100644 --- a/tests/distribution_train.py +++ b/tests/distribution_train.py @@ -20,7 +20,7 @@ if __name__ == '__main__': algo = 'NaturalPG' - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 4096, diff --git a/tests/saved_source/Simple-v0.npz b/tests/saved_source/Test-v0.npz similarity index 100% rename from tests/saved_source/Simple-v0.npz rename to tests/saved_source/Test-v0.npz diff --git a/tests/simple_env.py b/tests/simple_env.py index 7d8b7eba9..7d9c4841c 
100644 --- a/tests/simple_env.py +++ b/tests/simple_env.py @@ -28,10 +28,10 @@ @env_register -class SimpleEnv(CMDP): +class TestEnv(CMDP): """Simplest environment for testing.""" - _support_envs: ClassVar[list[str]] = ['Simple-v0'] + _support_envs: ClassVar[list[str]] = ['Test-v0'] metadata: ClassVar[dict[str, int]] = {'render_fps': 30} need_auto_reset_wrapper = True need_time_limit_wrapper = True @@ -61,9 +61,14 @@ def step( reward = 10000 * torch.as_tensor(random.random()) cost = 10000 * torch.as_tensor(random.random()) terminated = torch.as_tensor(random.random() > 0.9) - truncated = torch.as_tensor(self._count > 10) + truncated = torch.as_tensor(self._count > self.max_episode_steps) return obs, reward, cost, terminated, truncated, {'final_observation': obs} + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return 10 + def reset( self, seed: int | None = None, @@ -78,9 +83,6 @@ def reset( def set_seed(self, seed: int) -> None: random.seed(seed) - def sample_action(self) -> torch.Tensor: - return torch.as_tensor(self._action_space.sample()) - def render(self) -> Any: return np.zeros((100, 100, 3), dtype=np.uint8) diff --git a/tests/test_env.py b/tests/test_env.py index 8bafb457a..3e75695c8 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -14,6 +14,7 @@ # ============================================================================== """Test envs.""" +import torch from gymnasium.spaces import Box import helpers @@ -41,7 +42,7 @@ def test_safety_gymnasium(num_envs) -> None: else: assert obs.shape == (obs_space.shape[0],) - act = env.sample_action() + act = torch.zeros(env.action_space.shape, dtype=torch.float32) if num_envs > 1: act = act.repeat(num_envs, 1) @@ -93,7 +94,7 @@ def test_safety_gymnasium_modelbased(num_envs: int) -> None: else: assert obs.shape == (obs_space.shape[0],) - act = env.sample_action() + act = torch.zeros(env.action_space.shape, dtype=torch.float32) if num_envs > 1: act = act.repeat(num_envs, 1) @@ -135,7 +136,7 @@ def test_mujoco(num_envs, env_id) -> None: else: assert obs.shape == (obs_space.shape[0],) - act = env.sample_action() + act = torch.zeros(env.action_space.shape, dtype=torch.float32) if num_envs > 1: act = act.repeat(num_envs, 1) diff --git a/tests/test_policy.py b/tests/test_policy.py index b58db6181..79810d0b9 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -54,7 +54,7 @@ @helpers.parametrize(optim_case=optim_case) def test_cpo(optim_case): - agent = omnisafe.Agent('CPO', 'Simple-v0', custom_cfgs={}) + agent = omnisafe.Agent('CPO', 'Test-v0', custom_cfgs={}) b_grads = torch.Tensor([1]) ep_costs = torch.Tensor([-1]) r = torch.Tensor([0]) @@ -113,7 +113,7 @@ def test_cpo(optim_case): def test_assertion_error(): """Test base algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -163,7 +163,7 @@ def test_assertion_error(): def test_render(): """Test render image""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -188,7 +188,7 @@ def test_render(): @helpers.parametrize(algo=['PETS', 'CCEPETS', 'CAPPETS', 'RCEPETS']) def test_cem_based(algo): """Test model_based algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { @@ -232,7 +232,7 @@ def test_cem_based(algo): @helpers.parametrize(algo=['LOOP', 'SafeLOOP']) def test_loop(algo): """Test model_based algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { @@ -281,7 
+281,7 @@ def test_loop(algo): @helpers.parametrize(algo=off_base) def test_off_policy(algo): """Test base algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -310,7 +310,7 @@ def test_off_policy(algo): @helpers.parametrize(algo=off_lag) def test_off_lag_policy(algo): """Test base algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -343,7 +343,7 @@ def test_off_lag_policy(algo): @helpers.parametrize(auto_alpha=auto_alpha) def test_sac_policy(auto_alpha): """Test sac algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -374,7 +374,7 @@ def test_sac_policy(auto_alpha): @helpers.parametrize(auto_alpha=auto_alpha, algo=sac_lag) def test_sac_lag_policy(auto_alpha, algo): """Test sac algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -415,7 +415,7 @@ def test_sac_lag_policy(auto_alpha, algo): ) def test_on_policy(algo): """Test base algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -438,7 +438,7 @@ def test_on_policy(algo): @helpers.parametrize(algo=pid_lagrange_policy) def test_pid(algo): """Test pid algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -466,8 +466,8 @@ def test_pid(algo): ) def test_offline(algo): """Test base algorithms.""" - env_id = 'Simple-v0' - dataset = os.path.join(os.path.dirname(__file__), 'saved_source', 'Simple-v0.npz') + env_id = 'Test-v0' + dataset = os.path.join(os.path.dirname(__file__), 'saved_source', 'Test-v0.npz') custom_cfgs = { 'train_cfgs': { 'total_steps': 4, @@ -488,8 +488,8 @@ def test_offline(algo): ) def test_coptidice(fn_type): """Test coptidice algorithms.""" - env_id = 'Simple-v0' - dataset = os.path.join(os.path.dirname(__file__), 'saved_source', 'Simple-v0.npz') + env_id = 'Test-v0' + dataset = os.path.join(os.path.dirname(__file__), 'saved_source', 'Test-v0.npz') custom_cfgs = { 'train_cfgs': { 'total_steps': 4, @@ -509,7 +509,7 @@ def test_coptidice(fn_type): @helpers.parametrize(algo=['PPO', 'SAC', 'PPOLag']) def test_workflow_for_training(algo): """Test base algorithms.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, @@ -536,7 +536,7 @@ def test_workflow_for_training(algo): def test_std_anealing(): """Test std_anealing.""" - env_id = 'Simple-v0' + env_id = 'Test-v0' custom_cfgs = { 'train_cfgs': { 'total_steps': 200, diff --git a/tests/test_registry.py b/tests/test_registry.py index 7d96e6f35..0e427d310 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -17,6 +17,7 @@ import pytest from omnisafe.algorithms.registry import Registry +from omnisafe.envs.core import CMDP, env_register, env_unregister class TestRegistry: @@ -31,3 +32,21 @@ def test_with_error() -> None: registry.register('test') with pytest.raises(KeyError): registry.get('test') + with pytest.raises(TypeError): + + @env_register + class TestEnv: + name: str = 'test' + idx: int = 0 + + with pytest.raises(ValueError): + + @env_register + class CustomEnv(CMDP): + pass + + @env_register + @env_unregister + @env_unregister + class CustomEnv(CMDP): # noqa + _support_envs = ['Simple-v0'] # noqa diff --git a/tutorials/English/3.Environment Customization from Zero.ipynb b/tutorials/English/3.Environment Customization from Zero.ipynb new file mode 100644 
index 000000000..9cc517d08 --- /dev/null +++ b/tutorials/English/3.Environment Customization from Zero.ipynb @@ -0,0 +1,1439 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization From Zero\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Safety-Gymnasium: https://www.safety-gymnasium.com/\n", + "\n", + "[Safety-Gymnasium](https://www.safety-gymnasium.com/) is a highly scalable and customizable Safe Reinforcement Learning library, aiming to deliver a good view of benchmarking Safe Reinforcement Learning (Safe RL) algorithms and a more standardized setting of environments. \n", + "\n", + "## Introduction\n", + "\n", + "This section, along with [Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb), introduces how to enjoy the full set of training, recording, and saving frameworks provided by OmniSafe for customized environments. This section focuses on introducing beginners to SafeRL on how to create an environment from scratch, while [Tutorial 4: Environment Customization from Community](./4.Gymnasium%20Customization.ipynb) focuses on how to make minimal adaptations to existing community environments, such as [Gymnasium](https://github.com/Farama-Foundation/Gymnasium), to embed them in OmniSafe.\n", + "\n", + "Specifically, this section provides a simplest template for customizing environments. Through this template, you will understand:\n", + "\n", + "- How to create and register an environment in OmniSafe.\n", + "- How to specify customization parameters when creating an environment.\n", + "- How to record environment-specific information.\n", + "\n", + "## Quick Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install via pip (ignore it if you have already installed).\n", + "%pip install omnisafe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install from source (ignore it if you have already installed).\n", + "## clone the repo\n", + "%git clone https://github.com/PKU-Alignment/omnisafe\n", + "%cd omnisafe\n", + "\n", + "## install it\n", + "%pip install -e ." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The Simplest Custom Environment Template\n", + "The customized environment of OmniSafe can be implemented through a single file. We will introduce you to the simplest custom environment template, which will serve as a quick start.\n", + "\n", + "### Custom Environment Design\n", + "Here, we will detail the design process of a simple random environment. If you are an expert in RL or an experienced researcher, you can skip this module to [Custom Environment Embedding](#custom-environment-embedding) or [Tutorial 4: Environment Customization from Community](./4.Gymnasium%20Customization.ipynb)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# import all we need\n", + "from __future__ import annotations\n", + "\n", + "import random\n", + "import omnisafe\n", + "from typing import Any, ClassVar\n", + "\n", + "import torch\n", + "from gymnasium import spaces\n", + "\n", + "from omnisafe.envs.core import CMDP, env_register, env_unregister" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Define environment class\n", + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to pay attention to the following details in the above code:\n", + "\n", + "- **Task name definition** The supported task names for the environment are provided in `_support_envs`.\n", + "- **Wrapper configuration** Automatic reset and time limit are defined by setting `need_auto_reset_wrapper` and `need_time_limit_wrapper`.\n", + "- **Number of parallel environments** If your environment supports vectorized parallelism, set it through the `_num_envs` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Complete the `__init__` method definition. Here, you need to specify the action space and observation space of the environment. You need to define these according to the specific task you are currently designing. For example:\n", + "```python\n", + "if env_id == 'Example-v0':\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "elif env_id == 'Example-v1':\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + "else:\n", + " raise NotImplementedError\n", + "```\n", + "**Note:** As it is necessary to provide a standard interface for the higher-level modules, please follow these two variable names, i.e., `self._observation_space` and `self._action_space`, when designing the environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Complete the definition of methods related to environment initialization. `reset` and `set_seed` are the standard interfaces for OmniSafe environment initialization. Where `reset` resets the environment state and the step counter. Meanwhile, `set_seed` ensures the reproducibility of experiments by setting the random seed. 
The `max_episode_steps` method, decorated with `@property`, is used to pass the maximum number of steps per episode that need to be limited to the `TimeLimit` Wrapper. The code is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Complete the definition of functional methods. The `render` method is used for rendering the environment; the `close` method is used for cleanup after training ends." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10\n", + "\n", + " def render(self) -> Any:\n", + " pass\n", + "\n", + " def close(self) -> None:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Complete the definition of the step method. Here is the core interaction logic of your customized environment. You only need to adjust according to the data input and output format in this example. You can also directly change the random interaction dynamics in this example to the dynamics of your environment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10\n", + "\n", + " def render(self) -> Any:\n", + " pass\n", + "\n", + " def close(self) -> None:\n", + " pass\n", + "\n", + " def step(\n", + " self,\n", + " action: torch.Tensor,\n", + " ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:\n", + " self._count += 1\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " reward = 2 * torch.as_tensor(random.random())\n", + " cost = 2 * torch.as_tensor(random.random())\n", + " terminated = torch.as_tensor(random.random() > 0.9)\n", + " truncated = torch.as_tensor(self._count > 10)\n", + " return obs, reward, cost, terminated, truncated, {'final_observation': obs}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's try to run the environment for 10 time steps and observe the interaction information." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------\n", + "obs: tensor([0.5903, 0.9825, 0.6966])\n", + "reward: 1.6888437271118164\n", + "cost: 1.5159088373184204\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.0615, 0.8932, -0.1051])\n", + "reward: 0.5178334712982178\n", + "cost: 1.0225493907928467\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.7570, -0.0613, 0.9682])\n", + "reward: 1.5675971508026123\n", + "cost: 0.6066254377365112\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.1937, 0.5437, -0.4663])\n", + "reward: 1.1667640209197998\n", + "cost: 1.8162257671356201\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.9458, -0.1812, 0.4118])\n", + "reward: 0.5636757016181946\n", + "cost: 1.511608362197876\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.9290, -0.0350, 0.3893])\n", + "reward: 0.5010126829147339\n", + "cost: 1.8194924592971802\n", + "terminated: True\n", + "truncated: False\n", + "********************\n" + ] + } + ], + "source": [ + "env = ExampleEnv(env_id='Example-v0')\n", + "env.reset(seed=0)\n", + "while True:\n", + " action = env.action_space.sample()\n", + " obs, reward, cost, terminated, truncated, info = env.step(action)\n", + " print('-' * 20)\n", + " print(f'obs: {obs}')\n", + " print(f'reward: {reward}')\n", + " print(f'cost: {cost}')\n", + " print(f'terminated: {terminated}')\n", + " print(f'truncated: {truncated}')\n", + " print('*' * 20)\n", + " if terminated or truncated:\n", + " break\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Congratulations! You have successfully completed the basic environment definition. Next, we will introduce how to register this environment into OmniSafe, and implement steps such as environment parameter passing, interaction information recording, algorithm training, and result saving." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom Environment Embedding\n", + "\n", + "### Quick Training\n", + "\n", + "Thanks to the carefully designed registration mechanism of OmniSafe, we only need one decorator to register this environment into the OmniSafe's environment list." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "@env_register\n", + "class ExampleEnv(ExampleEnv):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Registering an environment with the same name will cause an error, due to **environment name conflict**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "class CustomExampleEnv(ExampleEnv):\n", + " example_configs = 1\n", + "\n", + "\n", + "env = CustomExampleEnv('Example-v0')\n", + "env.example_configs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So, you need to manually unregister the environment first." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "@env_unregister\n", + "class CustomExampleEnv(ExampleEnv):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Afterwards, you can re-register the environment. In this tutorial, we will nest both the `env_register` and `env_unregister` decorators together. This is to avoid errors caused by repeated registration of the environment, ensuring that the environment is registered only once, so users can modify and run the code multiple times while reading this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CustomExampleEnv has not been registered yet\n" + ] + }, + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class CustomExampleEnv(ExampleEnv):\n", + " example_configs = 2\n", + "\n", + "\n", + "env = CustomExampleEnv('Example-v0')\n", + "env.example_configs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Subsequently, you can use the algorithms in OmniSafe to train this custom environment." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Example-v0}/seed-000-2024-04-09-15-08-37/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-08-37/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
+       "for Jupyter support\n",
+       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
+       "
\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 5.625942230224609       │\n",
+       "│ Metrics/EpCost                 │ 6.960921287536621       │\n",
+       "│ Metrics/EpLen                  │ 5.0                     │\n",
+       "│ Train/Epoch                    │ 0.0                     │\n",
+       "│ Train/Entropy                  │ 1.4189385175704956      │\n",
+       "│ Train/KL                       │ 0.0002234023268101737   │\n",
+       "│ Train/StopIter                 │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
+       "│ Train/LR                       │ 0.00019999999494757503  │\n",
+       "│ Train/PolicyStd                │ 1.0                     │\n",
+       "│ TotalEnvSteps                  │ 10.0                    │\n",
+       "│ Loss/Loss_pi                   │ 7.748603536583687e-08   │\n",
+       "│ Loss/Loss_pi/Delta             │ 7.748603536583687e-08   │\n",
+       "│ Value/Adv                      │ -1.7881394143159923e-08 │\n",
+       "│ Loss/Loss_reward_critic        │ 10.457597732543945      │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 10.457597732543945      │\n",
+       "│ Value/reward                   │ -0.012156231328845024   │\n",
+       "│ Loss/Loss_cost_critic          │ 18.316673278808594      │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 18.316673278808594      │\n",
+       "│ Value/cost                     │ 0.1599183827638626      │\n",
+       "│ Time/Total                     │ 0.03895211219787598     │\n",
+       "│ Time/Rollout                   │ 0.021677017211914062    │\n",
+       "│ Time/Update                    │ 0.01619410514831543     │\n",
+       "│ Time/Epoch                     │ 0.0379033088684082      │\n",
+       "│ Time/FPS                       │ 263.8358459472656       │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
+       "└────────────────────────────────┴─────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.0002234023268101737 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.00019999999494757503 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ 7.748603536583687e-08 │\n", + "│ Loss/Loss_pi/Delta │ 7.748603536583687e-08 │\n", + "│ Value/Adv │ -1.7881394143159923e-08 │\n", + "│ Loss/Loss_reward_critic │ 10.457597732543945 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.457597732543945 │\n", + "│ Value/reward │ -0.012156231328845024 │\n", + "│ Loss/Loss_cost_critic │ 18.316673278808594 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.316673278808594 │\n", + "│ Value/cost │ 0.1599183827638626 │\n", + "│ Time/Total │ 0.03895211219787598 │\n", + "│ Time/Rollout │ 0.021677017211914062 │\n", + "│ Time/Update │ 0.01619410514831543 │\n", + "│ Time/Epoch │ 0.0379033088684082 │\n", + "│ Time/FPS │ 263.8358459472656 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 10.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m10.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                  ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 7.8531365394592285     │\n",
+       "│ Metrics/EpCost                 │ 7.931504726409912      │\n",
+       "│ Metrics/EpLen                  │ 6.666666507720947      │\n",
+       "│ Train/Epoch                    │ 1.0                    │\n",
+       "│ Train/Entropy                  │ 1.4192386865615845     │\n",
+       "│ Train/KL                       │ 8.405959670199081e-05  │\n",
+       "│ Train/StopIter                 │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
+       "│ Train/LR                       │ 9.999999747378752e-05  │\n",
+       "│ Train/PolicyStd                │ 1.0003000497817993     │\n",
+       "│ TotalEnvSteps                  │ 20.0                   │\n",
+       "│ Loss/Loss_pi                   │ -8.940696716308594e-08 │\n",
+       "│ Loss/Loss_pi/Delta             │ -1.668930025289228e-07 │\n",
+       "│ Value/Adv                      │ 8.940696716308594e-08  │\n",
+       "│ Loss/Loss_reward_critic        │ 37.962928771972656     │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 27.50533103942871      │\n",
+       "│ Value/reward                   │ -0.00784378219395876   │\n",
+       "│ Loss/Loss_cost_critic          │ 25.662063598632812     │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 7.345390319824219      │\n",
+       "│ Value/cost                     │ 0.11082335561513901    │\n",
+       "│ Time/Total                     │ 0.08216094970703125    │\n",
+       "│ Time/Rollout                   │ 0.01664590835571289    │\n",
+       "│ Time/Update                    │ 0.013554811477661133   │\n",
+       "│ Time/Epoch                     │ 0.03022909164428711    │\n",
+       "│ Time/FPS                       │ 330.8123779296875      │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
+       "└────────────────────────────────┴────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 7.8531365394592285 │\n", + "│ Metrics/EpCost │ 7.931504726409912 │\n", + "│ Metrics/EpLen │ 6.666666507720947 │\n", + "│ Train/Epoch │ 1.0 │\n", + "│ Train/Entropy │ 1.4192386865615845 │\n", + "│ Train/KL │ 8.405959670199081e-05 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 9.999999747378752e-05 │\n", + "│ Train/PolicyStd │ 1.0003000497817993 │\n", + "│ TotalEnvSteps │ 20.0 │\n", + "│ Loss/Loss_pi │ -8.940696716308594e-08 │\n", + "│ Loss/Loss_pi/Delta │ -1.668930025289228e-07 │\n", + "│ Value/Adv │ 8.940696716308594e-08 │\n", + "│ Loss/Loss_reward_critic │ 37.962928771972656 │\n", + "│ Loss/Loss_reward_critic/Delta │ 27.50533103942871 │\n", + "│ Value/reward │ -0.00784378219395876 │\n", + "│ Loss/Loss_cost_critic │ 25.662063598632812 │\n", + "│ Loss/Loss_cost_critic/Delta │ 7.345390319824219 │\n", + "│ Value/cost │ 0.11082335561513901 │\n", + "│ Time/Total │ 0.08216094970703125 │\n", + "│ Time/Rollout │ 0.01664590835571289 │\n", + "│ Time/Update │ 0.013554811477661133 │\n", + "│ Time/Epoch │ 0.03022909164428711 │\n", + "│ Time/FPS │ 330.8123779296875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 9.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m9.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 6.297085762023926       │\n",
+       "│ Metrics/EpCost                 │ 6.2187700271606445      │\n",
+       "│ Metrics/EpLen                  │ 5.25                    │\n",
+       "│ Train/Epoch                    │ 2.0                     │\n",
+       "│ Train/Entropy                  │ 1.419387936592102       │\n",
+       "│ Train/KL                       │ 6.185231995914364e-06   │\n",
+       "│ Train/StopIter                 │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Mean         │ 0.9999998211860657      │\n",
+       "│ Train/PolicyRatio/Min          │ 0.9999998211860657      │\n",
+       "│ Train/PolicyRatio/Max          │ 0.9999998211860657      │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
+       "│ Train/LR                       │ 0.0                     │\n",
+       "│ Train/PolicyStd                │ 1.0004496574401855      │\n",
+       "│ TotalEnvSteps                  │ 30.0                    │\n",
+       "│ Loss/Loss_pi                   │ 7.152557657263969e-08   │\n",
+       "│ Loss/Loss_pi/Delta             │ 1.6093254373572563e-07  │\n",
+       "│ Value/Adv                      │ -1.4305115314527939e-07 │\n",
+       "│ Loss/Loss_reward_critic        │ 34.879573822021484      │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ -3.083354949951172      │\n",
+       "│ Value/reward                   │ 0.020589731633663177    │\n",
+       "│ Loss/Loss_cost_critic          │ 27.62775230407715       │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 1.965688705444336       │\n",
+       "│ Value/cost                     │ 0.13300421833992004     │\n",
+       "│ Time/Total                     │ 0.12445831298828125     │\n",
+       "│ Time/Rollout                   │ 0.0154266357421875      │\n",
+       "│ Time/Update                    │ 0.009746313095092773    │\n",
+       "│ Time/Epoch                     │ 0.02520155906677246     │\n",
+       "│ Time/FPS                       │ 396.81585693359375      │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
+       "└────────────────────────────────┴─────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 6.297085762023926 │\n", + "│ Metrics/EpCost │ 6.2187700271606445 │\n", + "│ Metrics/EpLen │ 5.25 │\n", + "│ Train/Epoch │ 2.0 │\n", + "│ Train/Entropy │ 1.419387936592102 │\n", + "│ Train/KL │ 6.185231995914364e-06 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Min │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Max │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0004496574401855 │\n", + "│ TotalEnvSteps │ 30.0 │\n", + "│ Loss/Loss_pi │ 7.152557657263969e-08 │\n", + "│ Loss/Loss_pi/Delta │ 1.6093254373572563e-07 │\n", + "│ Value/Adv │ -1.4305115314527939e-07 │\n", + "│ Loss/Loss_reward_critic │ 34.879573822021484 │\n", + "│ Loss/Loss_reward_critic/Delta │ -3.083354949951172 │\n", + "│ Value/reward │ 0.020589731633663177 │\n", + "│ Loss/Loss_cost_critic │ 27.62775230407715 │\n", + "│ Loss/Loss_cost_critic/Delta │ 1.965688705444336 │\n", + "│ Value/cost │ 0.13300421833992004 │\n", + "│ Time/Total │ 0.12445831298828125 │\n", + "│ Time/Rollout │ 0.0154266357421875 │\n", + "│ Time/Update │ 0.009746313095092773 │\n", + "│ Time/Epoch │ 0.02520155906677246 │\n", + "│ Time/FPS │ 396.81585693359375 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(6.297085762023926, 6.2187700271606445, 5.25)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 30,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 10,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Example-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Well done! We have completed the embedding and training of this customized environment. Next, we will further explore how to specify hyperparameters for the environment.\n", + "\n", + "### Parameter Setting\n", + "\n", + "Starting with a new example environment, assume this environment requires a parameter named `num_agents`. We will show how to complete the parameter setting without modifying OmniSafe's code." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NewExampleEnv has not been registered yet\n" + ] + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class NewExampleEnv(ExampleEnv): # make a new environment\n", + " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n", + " num_agents: ClassVar[int] = 1\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n", + " self.num_agents = kwargs.get('num_agents', 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the `num_agents` parameter is set to a default value: `1`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_env = NewExampleEnv('NewExample-v0')\n", + "new_env.num_agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below we will show how to modify this parameter through OmniSafe's interface and train:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-08-46/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-08-46/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'env_cfgs': {'num_agents': 2}})\n", + "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n", + "agent.agent._env._env.num_agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Excellent! We have set `num_agents` to 2. This means we have successfully implemented hyperparameter setting without modifying the code.\n", + "\n", + "### Training Information Recording\n", + "\n", + "While running the training code, you may have noticed that OmniSafe records training information through `Logger`, for example:\n", + "\n", + "```bash\n", + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 2.046875 │\n", + "│ Metrics/EpCost │ 2.89453125 │\n", + "│ Metrics/EpLen │ 3.25 │\n", + "│ Train/Epoch │ 3.0 │\n", + "...\n", + "```\n", + "So, can we output information from the environment into the log? The answer is yes, and this process also does not require modifying OmniSafe's code. You only need to implement two standard interfaces:\n", + "1. In the `__init__` function, add the information you want to output to `self.env_spec_log`.\n", + "2. Instantiate the `spec_log` function to record the required information.\n", + "\n", + "**Please note:** Currently, OmniSafe only supports recording this information at the end of each epoch, not after each step." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class NewExampleEnv(ExampleEnv):\n", + " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n", + "\n", + " # define what to log\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n", + " self.env_spec_log = {'Env/Success_counts': 0}\n", + "\n", + " # interact with the environment and log\n", + " def step(self, action):\n", + " obs, reward, cost, terminated, truncated, info = super().step(action)\n", + " success = int(reward > cost)\n", + " self.env_spec_log['Env/Success_counts'] += success\n", + " return obs, reward, cost, terminated, truncated, info\n", + "\n", + " # write to logger\n", + " def spec_log(self, logger) -> dict[str, Any]:\n", + " logger.store({'Env/Success_counts': self.env_spec_log['Env/Success_counts']})\n", + " self.env_spec_log['Env/Success_counts'] = 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we will briefly train and observe whether this information has been successfully recorded." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-08-52/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-08-52/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 5.625942230224609       │\n",
+       "│ Metrics/EpCost                 │ 6.960921287536621       │\n",
+       "│ Metrics/EpLen                  │ 5.0                     │\n",
+       "│ Train/Epoch                    │ 0.0                     │\n",
+       "│ Train/Entropy                  │ 1.4189385175704956      │\n",
+       "│ Train/KL                       │ 0.00026566203450784087  │\n",
+       "│ Train/StopIter                 │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
+       "│ Train/LR                       │ 0.0                     │\n",
+       "│ Train/PolicyStd                │ 1.0                     │\n",
+       "│ TotalEnvSteps                  │ 10.0                    │\n",
+       "│ Loss/Loss_pi                   │ -2.9802322387695312e-08 │\n",
+       "│ Loss/Loss_pi/Delta             │ -2.9802322387695312e-08 │\n",
+       "│ Value/Adv                      │ 5.9604645663569045e-09  │\n",
+       "│ Loss/Loss_reward_critic        │ 10.46424674987793       │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 10.46424674987793       │\n",
+       "│ Value/reward                   │ -0.017885426059365273   │\n",
+       "│ Loss/Loss_cost_critic          │ 18.490144729614258      │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 18.490144729614258      │\n",
+       "│ Value/cost                     │ 0.13730722665786743     │\n",
+       "│ Time/Total                     │ 0.0326535701751709      │\n",
+       "│ Time/Rollout                   │ 0.019308805465698242    │\n",
+       "│ Time/Update                    │ 0.012392044067382812    │\n",
+       "│ Time/Epoch                     │ 0.03173708915710449     │\n",
+       "│ Time/FPS                       │ 315.0982360839844       │\n",
+       "│ Env/Success_counts             │ 1.5                     │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
+       "└────────────────────────────────┴─────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00026566203450784087 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -2.9802322387695312e-08 │\n", + "│ Loss/Loss_pi/Delta │ -2.9802322387695312e-08 │\n", + "│ Value/Adv │ 5.9604645663569045e-09 │\n", + "│ Loss/Loss_reward_critic │ 10.46424674987793 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.46424674987793 │\n", + "│ Value/reward │ -0.017885426059365273 │\n", + "│ Loss/Loss_cost_critic │ 18.490144729614258 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.490144729614258 │\n", + "│ Value/cost │ 0.13730722665786743 │\n", + "│ Time/Total │ 0.0326535701751709 │\n", + "│ Time/Rollout │ 0.019308805465698242 │\n", + "│ Time/Update │ 0.012392044067382812 │\n", + "│ Time/Epoch │ 0.03173708915710449 │\n", + "│ Time/FPS │ 315.0982360839844 │\n", + "│ Env/Success_counts │ 1.5 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(5.625942230224609, 6.960921287536621, 5.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'train_cfgs': {'total_steps': 10}})\n", + "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nice! The above code has outputted the environment-specific information `Env/Success_counts` to the terminal. This process does not require any modifications to the original code.\n", + "\n", + "## Summary\n", + "OmniSafe aims to become the foundational software for safe reinforcement learning. We will continue to refine the environmental interface standards of OmniSafe, enabling it to adapt to various safe reinforcement learning tasks and empower diverse safety scenarios." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/English/4.Environment Customization from Community.ipynb b/tutorials/English/4.Environment Customization from Community.ipynb new file mode 100644 index 000000000..d5123f7c9 --- /dev/null +++ b/tutorials/English/4.Environment Customization from Community.ipynb @@ -0,0 +1,916 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization from Community\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Gymnasium: https://github.com/Farama-Foundation/Gymnasium\n", + "\n", + "[Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API.\n", + "\n", + "## Introduction\n", + "\n", + "In this section, we will introduce how to embed an existing environment from the community into OmniSafe. The series of tasks provided by [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) have been widely applied in reinforcement learning. Specifically, this section will use [Pendulum-v1](https://gymnasium.farama.org/environments/classic_control/pendulum/) as an example to show how to embed Gymnasium's tasks into OmniSafe.\n", + "\n", + "## Quick Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install via pip (ignore it if you have already installed).\n", + "%pip install omnisafe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install from source (ignore it if you have already installed).\n", + "## clone the repo\n", + "%git clone https://github.com/PKU-Alignment/omnisafe\n", + "%cd omnisafe\n", + "\n", + "## install it\n", + "%pip install -e ." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gymnasium Task Embedding\n", + "The core part for environment embedding is to provide sufficient static or dynamic information for SafeRL agent interaction and training. This section will detail the variables that must be defined for embedding environments and the corresponding standards. We will first present the entire embedding process in the order of code organization, giving a preliminary understanding. Then, we will review all the codes, summarize, and organize the adaptations you need to make when customizing your environment.\n", + "\n", + "### Quick Start\n", + "First, import all external variables required for this tutorial." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# import all we need\n", + "from __future__ import annotations\n", + "\n", + "from typing import Any, ClassVar\n", + "import gymnasium\n", + "import torch\n", + "import numpy as np\n", + "import omnisafe\n", + "\n", + "from omnisafe.envs.core import CMDP, env_register, env_unregister\n", + "from omnisafe.typing import DEVICE_CPU" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, create a class named `ExampleMuJoCoEnv`, which needs to inherit from `CMDP`. (This is because we want to transform the environment's interaction form into the CMDP paradigm. You can define new abstract classes as needed to implement new paradigms)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleMuJoCoEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Pendulum-v1'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed\n", + "\n", + " def __init__(\n", + " self,\n", + " env_id: str,\n", + " num_envs: int = 1,\n", + " device: torch.device = DEVICE_CPU,\n", + " **kwargs: Any,\n", + " ) -> None:\n", + " super().__init__(env_id)\n", + " self._num_envs = num_envs\n", + " # Instantiate the environment object\n", + " self._env = gymnasium.make(id=env_id, autoreset=True, **kwargs)\n", + " # Specify the action space for initialization by the algorithm layer\n", + " self._action_space = self._env.action_space\n", + " # Specify the observation space for initialization by the algorithm layer\n", + " self._observation_space = self._env.observation_space\n", + " # Optional, for GPU acceleration. 
Default is CPU\n", + "self._device = device\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict[str, Any]]:\n", + " # Reset the environment\n", + " obs, info = self._env.reset(seed=seed, options=options)\n", + " # Convert the reset observations to a torch tensor.\n", + " return (\n", + " torch.as_tensor(obs, dtype=torch.float32, device=self._device),\n", + " info,\n", + " )\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> int | None:\n", + " # Return the maximum number of interaction steps per episode in the environment\n", + " return self._env.env.spec.max_episode_steps\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " # Set the environment's random seed for reproducibility\n", + " self.reset(seed=seed)\n", + "\n", + " def render(self) -> Any:\n", + " # Return the image rendered by the environment\n", + " return self._env.render()\n", + "\n", + " def close(self) -> None:\n", + " # Release the environment instance after training ends\n", + " self._env.close()\n", + "\n", + " def step(\n", + " self,\n", + " action: torch.Tensor,\n", + " ) -> tuple[\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " dict[str, Any],\n", + " ]:\n", + " # Read the dynamic information after interacting with the environment\n", + " obs, reward, terminated, truncated, info = self._env.step(\n", + " action.detach().cpu().numpy(),\n", + " )\n", + " # Gymnasium does not explicitly include safety constraints; this is just a placeholder.\n", + " cost = np.zeros_like(reward)\n", + " # Convert dynamic information into torch tensor.\n", + " obs, reward, cost, terminated, truncated = (\n", + " torch.as_tensor(x, dtype=torch.float32, device=self._device)\n", + " for x in (obs, reward, cost, terminated, truncated)\n", + " )\n", + " if 'final_observation' in info:\n", + " info['final_observation'] = np.array(\n", + " [\n", + " array if array is not None else np.zeros(obs.shape[-1])\n", + " for array in info['final_observation']\n", + " ],\n", + " )\n", + " # Convert the last observation recorded in info into a torch tensor.\n", + " info['final_observation'] = torch.as_tensor(\n", + " info['final_observation'],\n", + " dtype=torch.float32,\n", + " device=self._device,\n", + " )\n", + "\n", + " return obs, reward, cost, terminated, truncated, info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The comments above explain the specific meaning of each part of the code. For a more detailed walkthrough, please refer to [Tutorial 3: Environment Customization from Zero](./3.Environment%20Customization%20from%20Zero.ipynb). 
We summarize the key points as follows:\n", + "\n", + "- **Static variables needed for OmniSafe initialization**\n", + "\n", + "| Static Information | Required | Definition | Type | Example |\n", + "|:---:|:---:|:---:|:---:|:---:|\n", + "| `need_auto_reset_wrapper` | Yes | Whether an `AutoReset` Wrapper is needed | `bool` variable | `True` |\n", + "| `need_time_limit_wrapper` | Yes | Whether a `TimeLimit` Wrapper is needed | `bool` variable | `True` |\n", + "| `_action_space` | Yes | Action space | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(2,)` |\n", + "| `_observation_space` | Yes | Observation space | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(3,)` |\n", + "| `max_episode_steps` | Yes | The maximum number of interaction steps per episode in the environment | Function with `@property` decorator, returning a variable of type `int` or `None` | Refer to the code block above |\n", + "| `_num_envs` | No | Number of parallel environments | `int` variable | 5 |\n", + "| `_device` | No | Torch computing device | `torch.device` variable | `DEVICE_CPU` |\n", + "\n", + "- **Dynamic variables required by the environment for OmniSafe**\n", + "\n", + "OmniSafe's agents mainly interact dynamically with the environment through the `reset` and `step` functions. You need to ensure that the return type, number, and order of your customized environment match the examples above, more specifically:\n", + "\n", + "| Dynamic Information | Type | Number | Order |\n", + "|:---:|:---:|:---:|:---:|\n", + "| `step` | `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]` | 6 | `obs`, `reward`, `cost`, `terminated`, `truncated`, `info` |\n", + "| `reset` | `tuple[torch.Tensor, dict[str, Any]]` | 2 | `obs`, `info` |\n", + "\n", + "- **Precautions**\n", + "\n", + "1. Although `_num_envs` and `_device` are not mandatory, please retain the input interface for these two parameters in the `__init__` function.\n", + "2. `_num_envs` is an advanced parameter for instantiating multiple environments for parallel sampling, representing the number of environments instantiated. If your customized environment also supports specifying the parallel number, please specify it through `_num_envs` instead of defining a new interface." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Subsequently, by registering the above environment into OmniSafe with the registration decorator `@env_register`, you can complete the training." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ExampleMuJoCoEnv has not been registered yet\n", + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-14/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-14/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
+       "for Jupyter support\n",
+       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
+       "
\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                 ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ -1616.242431640625    │\n",
+       "│ Metrics/EpCost                 │ 0.0                   │\n",
+       "│ Metrics/EpLen                  │ 200.0                 │\n",
+       "│ Train/Epoch                    │ 0.0                   │\n",
+       "│ Train/Entropy                  │ 1.4185898303985596    │\n",
+       "│ Train/KL                       │ 0.0007516025798395276 │\n",
+       "│ Train/StopIter                 │ 1.0                   │\n",
+       "│ Train/PolicyRatio/Mean         │ 0.9966228604316711    │\n",
+       "│ Train/PolicyRatio/Min          │ 0.9966228604316711    │\n",
+       "│ Train/PolicyRatio/Max          │ 0.9966228604316711    │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0075334208086133    │\n",
+       "│ Train/LR                       │ 0.0                   │\n",
+       "│ Train/PolicyStd                │ 0.9996514320373535    │\n",
+       "│ TotalEnvSteps                  │ 200.0                 │\n",
+       "│ Loss/Loss_pi                   │ 0.08751548826694489   │\n",
+       "│ Loss/Loss_pi/Delta             │ 0.08751548826694489   │\n",
+       "│ Value/Adv                      │ -0.398242324590683    │\n",
+       "│ Loss/Loss_reward_critic        │ 16605.1796875         │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 16605.1796875         │\n",
+       "│ Value/reward                   │ 0.0049050007946789265 │\n",
+       "│ Loss/Loss_cost_critic          │ 0.052194785326719284  │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 0.052194785326719284  │\n",
+       "│ Value/cost                     │ 0.07966174930334091   │\n",
+       "│ Time/Total                     │ 0.21084904670715332   │\n",
+       "│ Time/Rollout                   │ 0.17566156387329102   │\n",
+       "│ Time/Update                    │ 0.03439140319824219   │\n",
+       "│ Time/Epoch                     │ 0.21008920669555664   │\n",
+       "│ Time/FPS                       │ 951.9786987304688     │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                   │\n",
+       "└────────────────────────────────┴───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1616.242431640625 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4185898303985596 │\n", + "│ Train/KL │ 0.0007516025798395276 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Min │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Max │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Std │ 0.0075334208086133 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996514320373535 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.08751548826694489 │\n", + "│ Loss/Loss_pi/Delta │ 0.08751548826694489 │\n", + "│ Value/Adv │ -0.398242324590683 │\n", + "│ Loss/Loss_reward_critic │ 16605.1796875 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16605.1796875 │\n", + "│ Value/reward │ 0.0049050007946789265 │\n", + "│ Loss/Loss_cost_critic │ 0.052194785326719284 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.052194785326719284 │\n", + "│ Value/cost │ 0.07966174930334091 │\n", + "│ Time/Total │ 0.21084904670715332 │\n", + "│ Time/Rollout │ 0.17566156387329102 │\n", + "│ Time/Update │ 0.03439140319824219 │\n", + "│ Time/Epoch │ 0.21008920669555664 │\n", + "│ Time/FPS │ 951.9786987304688 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(-1616.242431640625, 0.0, 200.0)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister # Avoid the \"environment has been registered\" error when rerunning cells\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + " pass\n", + "\n", + "\n", + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 200,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 200,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Advanced Usage\n", + "In addition to the aforementioned methods, environments from the community can also take advantage of OmniSafe's capabilities for specifying environment-specific parameters and recording information. We will detail the specific operational methods." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Specifying Specific Parameters\n", + "\n", + "Taking `Pendulum-v1` as an example, according to the Gymnasium documentation, a specific parameter `g`, which stands for gravitational acceleration, can be specified when creating this task. 
Let's first take a look at its default value:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-17/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-17/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "10.0" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister # Avoid the \"environment has been registered\" error when rerunning cells\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + " def __getattr__(self, name: str) -> Any:\n", + " \"\"\"Get the attribute of the environment.\"\"\"\n", + " if name.startswith('_'):\n", + " raise AttributeError(f'attempted to get missing private attribute {name}')\n", + " return getattr(self._env, name)\n", + "\n", + "\n", + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 200,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 200,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.agent._env._env.g" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We implemented a magic function named `__get_attr__` to call and view specific parameters in the currently instantiated environment. In this case, we find that the default value of the gravitational acceleration `g` is 10.0.\n", + "\n", + "By consulting the Gymnasium documentation, this parameter can be specified during the process of creating an environment with the `gymnasium.make` function. Does OmniSafe support the passing of specific parameters for customized environments? The answer is yes:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-20/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-20/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "9.8" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'env_cfgs': {'g': 9.8}})\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.agent._env._env.g" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nice! The value of gravitational acceleration has been changed to 9.8. We just need to operate on `env_cfgs`, specifying the key and value of the parameter to be customized, to achieve the passing of specific parameters for the environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Information Recording\n", + "\n", + "The `Pendulum-v1` task contains many specific dynamic pieces of information. We will introduce how to record these pieces of information using OmniSafe's `Logger`. Specifically, we will explain using the maximum and cumulative values of the angular velocity `angular_velocity` per episode as examples." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-23/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-23/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                          Value                 ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                   │ -1607.6717529296875   │\n",
+       "│ Metrics/EpCost                  │ 0.0                   │\n",
+       "│ Metrics/EpLen                   │ 200.0                 │\n",
+       "│ Train/Epoch                     │ 0.0                   │\n",
+       "│ Train/Entropy                   │ 1.418560266494751     │\n",
+       "│ Train/KL                        │ 0.0005777678452432156 │\n",
+       "│ Train/StopIter                  │ 1.0                   │\n",
+       "│ Train/PolicyRatio/Mean          │ 0.9981198310852051    │\n",
+       "│ Train/PolicyRatio/Min           │ 0.9981198310852051    │\n",
+       "│ Train/PolicyRatio/Max           │ 0.9981198310852051    │\n",
+       "│ Train/PolicyRatio/Std           │ 0.005412393249571323  │\n",
+       "│ Train/LR                        │ 0.0                   │\n",
+       "│ Train/PolicyStd                 │ 0.9996219277381897    │\n",
+       "│ TotalEnvSteps                   │ 200.0                 │\n",
+       "│ Loss/Loss_pi                    │ 0.09192709624767303   │\n",
+       "│ Loss/Loss_pi/Delta              │ 0.09192709624767303   │\n",
+       "│ Value/Adv                       │ -0.4177907109260559   │\n",
+       "│ Loss/Loss_reward_critic         │ 16393.2265625         │\n",
+       "│ Loss/Loss_reward_critic/Delta   │ 16393.2265625         │\n",
+       "│ Value/reward                    │ 0.00719139538705349   │\n",
+       "│ Loss/Loss_cost_critic           │ 0.05219484493136406   │\n",
+       "│ Loss/Loss_cost_critic/Delta     │ 0.05219484493136406   │\n",
+       "│ Value/cost                      │ 0.07949987053871155   │\n",
+       "│ Time/Total                      │ 0.20513606071472168   │\n",
+       "│ Time/Rollout                    │ 0.17486166954040527   │\n",
+       "│ Time/Update                     │ 0.029330968856811523  │\n",
+       "│ Time/Epoch                      │ 0.20422101020812988   │\n",
+       "│ Time/FPS                        │ 979.3323364257812     │\n",
+       "│ Env/Max_angular_velocity        │ 2.9994523525238037    │\n",
+       "│ Env/Cumulative_angular_velocity │ 1.0643725395202637    │\n",
+       "│ Metrics/LagrangeMultiplier/Mean │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Min  │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Max  │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Std  │ 0.0                   │\n",
+       "└─────────────────────────────────┴───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1607.6717529296875 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.418560266494751 │\n", + "│ Train/KL │ 0.0005777678452432156 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Min │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Max │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Std │ 0.005412393249571323 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996219277381897 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.09192709624767303 │\n", + "│ Loss/Loss_pi/Delta │ 0.09192709624767303 │\n", + "│ Value/Adv │ -0.4177907109260559 │\n", + "│ Loss/Loss_reward_critic │ 16393.2265625 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16393.2265625 │\n", + "│ Value/reward │ 0.00719139538705349 │\n", + "│ Loss/Loss_cost_critic │ 0.05219484493136406 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.05219484493136406 │\n", + "│ Value/cost │ 0.07949987053871155 │\n", + "│ Time/Total │ 0.20513606071472168 │\n", + "│ Time/Rollout │ 0.17486166954040527 │\n", + "│ Time/Update │ 0.029330968856811523 │\n", + "│ Time/Epoch │ 0.20422101020812988 │\n", + "│ Time/FPS │ 979.3323364257812 │\n", + "│ Env/Max_angular_velocity │ 2.9994523525238037 │\n", + "│ Env/Cumulative_angular_velocity │ 1.0643725395202637 │\n", + "│ Metrics/LagrangeMultiplier/Mean │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└─────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(-1607.6717529296875, 0.0, 200.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from omnisafe.common.logger import Logger\n", + "\n", + "\n", + "@env_register\n", + "@env_unregister # Avoid the \"environment has been registered\" error when rerunning cells\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + "\n", + " def __init__(self, env_id, num_envs, device, **kwargs):\n", + " super().__init__(env_id, num_envs, device, **kwargs)\n", + " self.env_spec_log = {\n", + " 'Env/Max_angular_velocity': 0.0,\n", + " 'Env/Cumulative_angular_velocity': 0.0,\n", + " } # Reiterate and specify in the constructor\n", + "\n", + " def spec_log(self, logger: Logger) -> None:\n", + " for key, value in self.env_spec_log.items():\n", + " logger.store({key: value})\n", + " self.env_spec_log[key] = 0.0\n", + "\n", + " def step(self, action):\n", + " obs, reward, cost, terminated, truncated, info = super().step(action=action)\n", + " angle = obs[-1].item()\n", + " self.env_spec_log['Env/Max_angular_velocity'] = max(\n", + " self.env_spec_log['Env/Max_angular_velocity'], angle\n", + " )\n", + " self.env_spec_log['Env/Cumulative_angular_velocity'] += angle\n", + " return obs, reward, cost, terminated, truncated, info\n", + "\n", + "\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! 
We successfully recorded the required environment-specific information in the `Logger`. It is worth noting that, in this process, we did not modify any of OmniSafe's source code." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "In this section, using Gymnasium's classic environment `Pendulum-v1`, we introduced the necessary interface adaptation and information provision required to embed an existing community environment into OmniSafe. We hope this tutorial is helpful for embedding your customized environment. If you wish to have your environment supported as one of the official OmniSafe environments, or if you encounter difficulties in customizing environments, you are welcome to communicate with us through the [Issues](https://github.com/PKU-Alignment/omnisafe/issues), [Pull Requests](https://github.com/PKU-Alignment/omnisafe/pulls), and [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions) modules." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/zh-cn/3.Environment Customization from Zero.ipynb b/tutorials/zh-cn/3.Environment Customization from Zero.ipynb new file mode 100644 index 000000000..14f45e272 --- /dev/null +++ b/tutorials/zh-cn/3.Environment Customization from Zero.ipynb @@ -0,0 +1,1440 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization From Zero\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Safety-Gymnasium: https://www.safety-gymnasium.com/\n", + "\n", + "[Safety-Gymnasium](https://www.safety-gymnasium.com/) is a highly scalable and customizable Safe Reinforcement Learning library, aiming to deliver a good view of benchmarking Safe Reinforcement Learning (Safe RL) algorithms and a more standardized setting of environments. \n", + "\n", + "## 引言\n", + "\n", + "本节与[Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb)共同介绍了如何令定制化环境享受OmniSafe提供的全套训练、记录与保存框架。本节侧重于面向安全强化学习初学者介绍如何从零开始创建环境;而[Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb)关注如何将社区已有的环境,例如[Gymnasium](https://github.com/Farama-Foundation/Gymnasium),作出最小适配,以嵌入OmniSafe中。\n", + "\n", + "具体而言,本节提供了一个用于定制化环境的最简单模版。通过该模版,您将了解:\n", + "\n", + "- 如何在OmniSafe中创建并注册一个环境。\n", + "- 如何指定创建环境时的定制化参数。\n", + "- 如何记录环境特定的信息。\n", + "\n", + "## 快速安装" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 通过pip安装(如果您已经安装,请忽略此段代码)\n", + "%pip install omnisafe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 通过源代码安装(如果您已经安装,请忽略此段代码)\n", + "## 克隆仓库\n", + "%git clone https://github.com/PKU-Alignment/omnisafe\n", + "%cd omnisafe\n", + "\n", + "## 完成安装\n", + "%pip install -e ." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 定制化环境最简模版\n", + "OmniSafe的定制化环境可以仅通过单个文件实现。我们将为您介绍一个最简的定制化环境模版,它将作为您入门的起点。\n", + "\n", + "### 定制化环境设计\n", + "我们将在此细致地介绍一个简易随机环境的设计过程。如果您是强化学习领域的专家或有经验的研究者,可以跳过该模块至[定制化环境嵌入](#定制化环境嵌入)或[Tutorial 4: Environment Customization from Community](./4.Gymnasium%20Customization.ipynb)。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# 导入必要的包\n", + "from __future__ import annotations\n", + "\n", + "import random\n", + "import omnisafe\n", + "from typing import Any, ClassVar\n", + "\n", + "import torch\n", + "from gymnasium import spaces\n", + "\n", + "from omnisafe.envs.core import CMDP, env_register, env_unregister" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# 定义环境类\n", + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "您需要关注上面这段代码的如下细节:\n", + "\n", + "- **任务名称定义** 在 `_support_envs`中提供环境受支持的任务名称。\n", + "- **Wrapper配置** 通过设定 `need_auto_reset_wrapper`和 `need_time_limit_wrapper` 来定义自动重置和限制时间。\n", + "- **并行环境数量** 如果您的环境支持向量化并行,请通过 `_num_envs` 参数进行设定。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成 `__init__`函数定义。此处需要给出环境的动作空间与观测空间。您需要根据您当前在设计的具体任务来定义。例如:\n", + "```python\n", + "if env_id == 'Example-v0':\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "elif env_id == 'Example-v1':\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + "else:\n", + " raise NotImplementedError\n", + "```\n", + "**请注意:** 由于需要为上层模块提供标准的接口,因此在设计环境时请遵循 `self._observation_space` 以及 `self._action_space` 这两个变量名**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成环境初始化相关函数的定义。`reset` 和 `set_seed` 是OmniSafe环境初始化的标准接口。其中 `reset` 重置环境状态与计步器。 `set_seed` 通过设定随机种子确保实验的可复现性。而带有`@property`装饰器的`max_episode_steps`函数用于为`TimeLimit` Wrapper传递需要限制的每幕最大步数。实现参考如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, 
shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成功能性函数的定义。`render` 函数用于渲染环境;`close` 函数用于训练结束后的清理。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10\n", + "\n", + " def render(self) -> Any:\n", + " pass\n", + "\n", + " def close(self) -> None:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成 `step` 函数定义。此处是您定制化环境的核心交互逻辑。您只需按照本例中的数据输入与输出格式进行调整即可。您也可以直接将本例中的随机交互动态更改为您的环境动态。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + " metadata: ClassVar[dict[str, int]] = {}\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10\n", + "\n", + " def render(self) -> Any:\n", + " pass\n", + "\n", + " def close(self) -> None:\n", + " pass\n", + "\n", + " def 
step(\n", + " self,\n", + " action: torch.Tensor,\n", + " ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:\n", + " self._count += 1\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " reward = 2 * torch.as_tensor(random.random())\n", + " cost = 2 * torch.as_tensor(random.random())\n", + " terminated = torch.as_tensor(random.random() > 0.9)\n", + " truncated = torch.as_tensor(self._count > 10)\n", + " return obs, reward, cost, terminated, truncated, {'final_observation': obs}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接下来,我们试着运行该环境10个时间步,观察交互信息。" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------\n", + "obs: tensor([-0.5552, 0.2905, 0.0094])\n", + "reward: 1.6888437271118164\n", + "cost: 1.5159088373184204\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.0635, -0.9966, -0.4681])\n", + "reward: 0.5178334712982178\n", + "cost: 1.0225493907928467\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.4385, 0.0678, -0.3470])\n", + "reward: 1.5675971508026123\n", + "cost: 0.6066254377365112\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.8278, -0.5252, -0.1799])\n", + "reward: 1.1667640209197998\n", + "cost: 1.8162257671356201\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.1086, -0.5711, 0.7751])\n", + "reward: 0.5636757016181946\n", + "cost: 1.511608362197876\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.3585, 0.8011, 0.2172])\n", + "reward: 0.5010126829147339\n", + "cost: 1.8194924592971802\n", + "terminated: True\n", + "truncated: False\n", + "********************\n" + ] + } + ], + "source": [ + "env = ExampleEnv(env_id='Example-v0')\n", + "env.reset(seed=0)\n", + "while True:\n", + " action = env.action_space.sample()\n", + " obs, reward, cost, terminated, truncated, info = env.step(action)\n", + " print('-' * 20)\n", + " print(f'obs: {obs}')\n", + " print(f'reward: {reward}')\n", + " print(f'cost: {cost}')\n", + " print(f'terminated: {terminated}')\n", + " print(f'truncated: {truncated}')\n", + " print('*' * 20)\n", + " if terminated or truncated:\n", + " break\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "恭喜您!已经成功完成了基础的环境定义,接下来,我们将介绍如何将该环境注册入OmniSafe中,并实现环境参数传递、交互信息记录、算法训练以及结果保存等步骤。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 定制化环境嵌入\n", + "\n", + "### 快速训练\n", + "\n", + "得益于OmniSafe精心设计的注册机制,我们只需一个装饰器即可将这个环境注册到OmniSafe的环境列表中。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "@env_register\n", + "class ExampleEnv(ExampleEnv):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "注册同名环境将会报错,这是由于**环境名称冲突**。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "class CustomExampleEnv(ExampleEnv):\n", 
+ " example_configs = 1\n", + "\n", + "\n", + "env = CustomExampleEnv('Example-v0')\n", + "env.example_configs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "这时,您需要先对环境手动取消注册。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "@env_unregister\n", + "class CustomExampleEnv(ExampleEnv):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "之后,您就可以重新注册该环境了。在本教程中,我们会同时嵌套 `env_register` 和 `env_unregister` 装饰器,这是为了避免环境重复注册造成报错,即确保该环境只被注册一次,以便用户在阅读本教程时多次修改与运行代码。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CustomExampleEnv has not been registered yet\n" + ] + }, + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class CustomExampleEnv(ExampleEnv):\n", + " example_configs = 2\n", + "\n", + "\n", + "env = CustomExampleEnv('Example-v0')\n", + "env.example_configs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "随后,您可以使用OmniSafe中的算法来训练这个自定义环境。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Example-v0}/seed-000-2024-04-09-15-04-56/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-04-56/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
+       "for Jupyter support\n",
+       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
+       "
\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 5.625942230224609       │\n",
+       "│ Metrics/EpCost                 │ 6.960921287536621       │\n",
+       "│ Metrics/EpLen                  │ 5.0                     │\n",
+       "│ Train/Epoch                    │ 0.0                     │\n",
+       "│ Train/Entropy                  │ 1.4189385175704956      │\n",
+       "│ Train/KL                       │ 0.00020748490351252258  │\n",
+       "│ Train/StopIter                 │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
+       "│ Train/LR                       │ 0.00019999999494757503  │\n",
+       "│ Train/PolicyStd                │ 1.0                     │\n",
+       "│ TotalEnvSteps                  │ 10.0                    │\n",
+       "│ Loss/Loss_pi                   │ -1.4901161193847656e-08 │\n",
+       "│ Loss/Loss_pi/Delta             │ -1.4901161193847656e-08 │\n",
+       "│ Value/Adv                      │ 1.4901161193847656e-08  │\n",
+       "│ Loss/Loss_reward_critic        │ 10.458966255187988      │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 10.458966255187988      │\n",
+       "│ Value/reward                   │ -0.015489530749619007   │\n",
+       "│ Loss/Loss_cost_critic          │ 19.141571044921875      │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 19.141571044921875      │\n",
+       "│ Value/cost                     │ 0.05426764488220215     │\n",
+       "│ Time/Total                     │ 0.034796953201293945    │\n",
+       "│ Time/Rollout                   │ 0.01762533187866211     │\n",
+       "│ Time/Update                    │ 0.01616811752319336     │\n",
+       "│ Time/Epoch                     │ 0.03383183479309082     │\n",
+       "│ Time/FPS                       │ 295.5858459472656       │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
+       "└────────────────────────────────┴─────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00020748490351252258 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.00019999999494757503 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -1.4901161193847656e-08 │\n", + "│ Loss/Loss_pi/Delta │ -1.4901161193847656e-08 │\n", + "│ Value/Adv │ 1.4901161193847656e-08 │\n", + "│ Loss/Loss_reward_critic │ 10.458966255187988 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.458966255187988 │\n", + "│ Value/reward │ -0.015489530749619007 │\n", + "│ Loss/Loss_cost_critic │ 19.141571044921875 │\n", + "│ Loss/Loss_cost_critic/Delta │ 19.141571044921875 │\n", + "│ Value/cost │ 0.05426764488220215 │\n", + "│ Time/Total │ 0.034796953201293945 │\n", + "│ Time/Rollout │ 0.01762533187866211 │\n", + "│ Time/Update │ 0.01616811752319336 │\n", + "│ Time/Epoch │ 0.03383183479309082 │\n", + "│ Time/FPS │ 295.5858459472656 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 10.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m10.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                  ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 7.8531365394592285     │\n",
+       "│ Metrics/EpCost                 │ 7.931504726409912      │\n",
+       "│ Metrics/EpLen                  │ 6.666666507720947      │\n",
+       "│ Train/Epoch                    │ 1.0                    │\n",
+       "│ Train/Entropy                  │ 1.4192386865615845     │\n",
+       "│ Train/KL                       │ 6.416345422621816e-05  │\n",
+       "│ Train/StopIter                 │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
+       "│ Train/LR                       │ 9.999999747378752e-05  │\n",
+       "│ Train/PolicyStd                │ 1.0003000497817993     │\n",
+       "│ TotalEnvSteps                  │ 20.0                   │\n",
+       "│ Loss/Loss_pi                   │ -6.258487417198921e-08 │\n",
+       "│ Loss/Loss_pi/Delta             │ -4.768371297814156e-08 │\n",
+       "│ Value/Adv                      │ 1.341104507446289e-07  │\n",
+       "│ Loss/Loss_reward_critic        │ 38.05686950683594      │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 27.59790325164795      │\n",
+       "│ Value/reward                   │ -0.008213319815695286  │\n",
+       "│ Loss/Loss_cost_critic          │ 23.737285614013672     │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 4.595714569091797      │\n",
+       "│ Value/cost                     │ 0.17113244533538818    │\n",
+       "│ Time/Total                     │ 0.0776519775390625     │\n",
+       "│ Time/Rollout                   │ 0.015673398971557617   │\n",
+       "│ Time/Update                    │ 0.011301994323730469   │\n",
+       "│ Time/Epoch                     │ 0.027007579803466797   │\n",
+       "│ Time/FPS                       │ 370.27294921875        │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
+       "└────────────────────────────────┴────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 7.8531365394592285 │\n", + "│ Metrics/EpCost │ 7.931504726409912 │\n", + "│ Metrics/EpLen │ 6.666666507720947 │\n", + "│ Train/Epoch │ 1.0 │\n", + "│ Train/Entropy │ 1.4192386865615845 │\n", + "│ Train/KL │ 6.416345422621816e-05 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 9.999999747378752e-05 │\n", + "│ Train/PolicyStd │ 1.0003000497817993 │\n", + "│ TotalEnvSteps │ 20.0 │\n", + "│ Loss/Loss_pi │ -6.258487417198921e-08 │\n", + "│ Loss/Loss_pi/Delta │ -4.768371297814156e-08 │\n", + "│ Value/Adv │ 1.341104507446289e-07 │\n", + "│ Loss/Loss_reward_critic │ 38.05686950683594 │\n", + "│ Loss/Loss_reward_critic/Delta │ 27.59790325164795 │\n", + "│ Value/reward │ -0.008213319815695286 │\n", + "│ Loss/Loss_cost_critic │ 23.737285614013672 │\n", + "│ Loss/Loss_cost_critic/Delta │ 4.595714569091797 │\n", + "│ Value/cost │ 0.17113244533538818 │\n", + "│ Time/Total │ 0.0776519775390625 │\n", + "│ Time/Rollout │ 0.015673398971557617 │\n", + "│ Time/Update │ 0.011301994323730469 │\n", + "│ Time/Epoch │ 0.027007579803466797 │\n", + "│ Time/FPS │ 370.27294921875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 9.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m9.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 6.297085762023926       │\n",
+       "│ Metrics/EpCost                 │ 6.2187700271606445      │\n",
+       "│ Metrics/EpLen                  │ 5.25                    │\n",
+       "│ Train/Epoch                    │ 2.0                     │\n",
+       "│ Train/Entropy                  │ 1.419387698173523       │\n",
+       "│ Train/KL                       │ 5.490810053743189e-06   │\n",
+       "│ Train/StopIter                 │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
+       "│ Train/LR                       │ 0.0                     │\n",
+       "│ Train/PolicyStd                │ 1.0004494190216064      │\n",
+       "│ TotalEnvSteps                  │ 30.0                    │\n",
+       "│ Loss/Loss_pi                   │ -1.9073486612342094e-07 │\n",
+       "│ Loss/Loss_pi/Delta             │ -1.2814999195143173e-07 │\n",
+       "│ Value/Adv                      │ 1.0728835775353218e-07  │\n",
+       "│ Loss/Loss_reward_critic        │ 34.77037811279297       │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ -3.2864913940429688     │\n",
+       "│ Value/reward                   │ 0.014150517992675304    │\n",
+       "│ Loss/Loss_cost_critic          │ 27.43436050415039       │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 3.6970748901367188      │\n",
+       "│ Value/cost                     │ 0.24021005630493164     │\n",
+       "│ Time/Total                     │ 0.12173724174499512     │\n",
+       "│ Time/Rollout                   │ 0.01879405975341797     │\n",
+       "│ Time/Update                    │ 0.011112689971923828    │\n",
+       "│ Time/Epoch                     │ 0.0299375057220459      │\n",
+       "│ Time/FPS                       │ 334.039794921875        │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
+       "└────────────────────────────────┴─────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 6.297085762023926 │\n", + "│ Metrics/EpCost │ 6.2187700271606445 │\n", + "│ Metrics/EpLen │ 5.25 │\n", + "│ Train/Epoch │ 2.0 │\n", + "│ Train/Entropy │ 1.419387698173523 │\n", + "│ Train/KL │ 5.490810053743189e-06 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0004494190216064 │\n", + "│ TotalEnvSteps │ 30.0 │\n", + "│ Loss/Loss_pi │ -1.9073486612342094e-07 │\n", + "│ Loss/Loss_pi/Delta │ -1.2814999195143173e-07 │\n", + "│ Value/Adv │ 1.0728835775353218e-07 │\n", + "│ Loss/Loss_reward_critic │ 34.77037811279297 │\n", + "│ Loss/Loss_reward_critic/Delta │ -3.2864913940429688 │\n", + "│ Value/reward │ 0.014150517992675304 │\n", + "│ Loss/Loss_cost_critic │ 27.43436050415039 │\n", + "│ Loss/Loss_cost_critic/Delta │ 3.6970748901367188 │\n", + "│ Value/cost │ 0.24021005630493164 │\n", + "│ Time/Total │ 0.12173724174499512 │\n", + "│ Time/Rollout │ 0.01879405975341797 │\n", + "│ Time/Update │ 0.011112689971923828 │\n", + "│ Time/Epoch │ 0.0299375057220459 │\n", + "│ Time/FPS │ 334.039794921875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(6.297085762023926, 6.2187700271606445, 5.25)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 30,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 10,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Example-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "干得不错!我们已经完成了这个定制化环境的嵌入和训练。接下来,我们将进一步研究如何为环境指定超参数。\n", + "\n", + "### 参数设定\n", + "\n", + "我们从一个新的示例环境出发,假设这个环境需要传入一个名为 `num_agents` 的参数。我们将展示如何不修改OmniSafe的代码来完成参数设定。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NewExampleEnv has not been registered yet\n" + ] + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class NewExampleEnv(ExampleEnv): # 创造一个新环境\n", + " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n", + " num_agents: ClassVar[int] = 1\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n", + " self.num_agents = kwargs.get('num_agents', 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "此时,`num_agents` 参数为预设值:`1`。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_env = 
NewExampleEnv('NewExample-v0')\n", + "new_env.num_agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "下面我们将展示如何通过 OmniSafe 的接口对该参数进行修改并训练:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-05-09/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-09/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'env_cfgs': {'num_agents': 2}})\n", + "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n", + "agent.agent._env._env.num_agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "非常好!我们将 `num_agents` 设置为了2。这表示我们在未修改代码的情形下成功实现了超参数设定。\n", + "\n", + "### 训练信息记录\n", + "\n", + "在运行训练代码时,您可能已经发现 OmniSafe 通过 `Logger` 记录了训练信息,例如:\n", + "\n", + "```bash\n", + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 2.046875 │\n", + "│ Metrics/EpCost │ 2.89453125 │\n", + "│ Metrics/EpLen │ 3.25 │\n", + "│ Train/Epoch │ 3.0 │\n", + "...\n", + "```\n", + "那么我们可否将环境之中的信息输出到日志中呢?答案是肯定的,而且这个过程同样不需要修改OmniSafe的代码。只需要实现两个标准接口:\n", + "1. 在 `__init__` 函数中,将需要输出的信息添加到`self.env_spec_log`中。\n", + "2. 实例化 `spec_log` 函数,记录所需的信息。\n", + "\n", + "**请注意:** 目前OmniSafe仅支持在每一个epoch结束时记录这些信息,而不支持在每一个step结束时记录。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class NewExampleEnv(ExampleEnv):\n", + " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n", + "\n", + " # 定义需要记录的信息\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n", + " self.env_spec_log = {'Env/Success_counts': 0}\n", + "\n", + " # 通过step函数,与环境进行交互\n", + " def step(self, action):\n", + " obs, reward, cost, terminated, truncated, info = super().step(action)\n", + " success = int(reward > cost)\n", + " self.env_spec_log['Env/Success_counts'] += success\n", + " return obs, reward, cost, terminated, truncated, info\n", + "\n", + " # 在logger中记录环境信息\n", + " def spec_log(self, logger) -> dict[str, Any]:\n", + " logger.store({'Env/Success_counts': self.env_spec_log['Env/Success_counts']})\n", + " self.env_spec_log['Env/Success_counts'] = 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接下来,我们简单训练观察该信息是否被成功记录。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-05-14/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-14/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                  ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ 5.625942230224609      │\n",
+       "│ Metrics/EpCost                 │ 6.960921287536621      │\n",
+       "│ Metrics/EpLen                  │ 5.0                    │\n",
+       "│ Train/Epoch                    │ 0.0                    │\n",
+       "│ Train/Entropy                  │ 1.4189385175704956     │\n",
+       "│ Train/KL                       │ 0.00024281258811242878 │\n",
+       "│ Train/StopIter                 │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
+       "│ Train/LR                       │ 0.0                    │\n",
+       "│ Train/PolicyStd                │ 1.0                    │\n",
+       "│ TotalEnvSteps                  │ 10.0                   │\n",
+       "│ Loss/Loss_pi                   │ -5.662441182607836e-08 │\n",
+       "│ Loss/Loss_pi/Delta             │ -5.662441182607836e-08 │\n",
+       "│ Value/Adv                      │ 1.2814999195143173e-07 │\n",
+       "│ Loss/Loss_reward_critic        │ 10.477845191955566     │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 10.477845191955566     │\n",
+       "│ Value/reward                   │ -0.0091781010851264    │\n",
+       "│ Loss/Loss_cost_critic          │ 18.525999069213867     │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 18.525999069213867     │\n",
+       "│ Value/cost                     │ 0.14141643047332764    │\n",
+       "│ Time/Total                     │ 0.030597209930419922   │\n",
+       "│ Time/Rollout                   │ 0.017596960067749023   │\n",
+       "│ Time/Update                    │ 0.012219905853271484   │\n",
+       "│ Time/Epoch                     │ 0.02985072135925293    │\n",
+       "│ Time/FPS                       │ 335.00830078125        │\n",
+       "│ Env/Success_counts             │ 1.5                    │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
+       "└────────────────────────────────┴────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00024281258811242878 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -5.662441182607836e-08 │\n", + "│ Loss/Loss_pi/Delta │ -5.662441182607836e-08 │\n", + "│ Value/Adv │ 1.2814999195143173e-07 │\n", + "│ Loss/Loss_reward_critic │ 10.477845191955566 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.477845191955566 │\n", + "│ Value/reward │ -0.0091781010851264 │\n", + "│ Loss/Loss_cost_critic │ 18.525999069213867 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.525999069213867 │\n", + "│ Value/cost │ 0.14141643047332764 │\n", + "│ Time/Total │ 0.030597209930419922 │\n", + "│ Time/Rollout │ 0.017596960067749023 │\n", + "│ Time/Update │ 0.012219905853271484 │\n", + "│ Time/Epoch │ 0.02985072135925293 │\n", + "│ Time/FPS │ 335.00830078125 │\n", + "│ Env/Success_counts │ 1.5 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(5.625942230224609, 6.960921287536621, 5.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'train_cfgs': {'total_steps': 10}})\n", + "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "漂亮!上述代码将在终端输出了环境特化的信息 `Env/Success_counts`。这一过程并不需要对原代码作出改动。\n", + "\n", + "## 总结\n", + "OmniSafe旨在成为安全强化学习的基础软件。我们将持续完善OmniSafe的环境接口标准,使OmniSafe能够适应各种安全强化学习任务,赋能多元安全场景。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/zh-cn/4.Environment Customization from Community.ipynb b/tutorials/zh-cn/4.Environment Customization from Community.ipynb new file mode 100644 index 000000000..6e05bd036 --- /dev/null +++ b/tutorials/zh-cn/4.Environment Customization from Community.ipynb @@ -0,0 +1,903 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization from Community\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Gymnasium: 
https://github.com/Farama-Foundation/Gymnasium\n",
+    "\n",
+    "[Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API.\n",
+    "\n",
+    "## 引言\n",
+    "\n",
+    "在本节当中,我们将为您介绍如何将一个来自社区的已有环境嵌入OmniSafe中。[Gymnasium](https://github.com/Farama-Foundation/Gymnasium)提供的系列任务已被广泛应用于强化学习中。具体而言,本节将以[Pendulum-v1](https://gymnasium.farama.org/environments/classic_control/pendulum/)为例,展示如何将Gymnasium的任务嵌入OmniSafe。\n",
+    "\n",
+    "## 快速安装"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 通过pip安装(如果您已经安装,请忽略此段代码)\n",
+    "%pip install omnisafe"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 通过源代码安装(如果您已经安装,请忽略此段代码)\n",
+    "## 克隆仓库\n",
+    "!git clone https://github.com/PKU-Alignment/omnisafe\n",
+    "%cd omnisafe\n",
+    "\n",
+    "## 完成安装\n",
+    "%pip install -e ."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Gymnasium任务嵌入\n",
+    "环境嵌入的核心是为SafeRL智能体的交互与训练提供足够的静态或动态信息,本节将详细介绍嵌入环境所必须定义的变量以及相应规范。我们将首先按照编写代码的逻辑顺序展示整个嵌入过程,让您有一个初步的了解。然后我们将回顾所有代码,总结并整理您在自定义环境时需要进行的适配。\n",
+    "\n",
+    "\n",
+    "### 快速开始\n",
+    "首先,导入本教程所需要的所有外部依赖。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 导入必要的包\n",
+    "from __future__ import annotations\n",
+    "\n",
+    "from typing import Any, ClassVar\n",
+    "import gymnasium\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "import omnisafe\n",
+    "\n",
+    "from omnisafe.envs.core import CMDP, env_register, env_unregister\n",
+    "from omnisafe.typing import DEVICE_CPU"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "随后,创建一个名为`ExampleMuJoCoEnv`的类,它需要继承的父类是`CMDP`(这是因为我们想把环境的交互形式转换为CMDP的范式,您可以根据需要定义新的抽象类以实现新的范式)。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ExampleMuJoCoEnv(CMDP):\n",
+    "    _support_envs: ClassVar[list[str]] = ['Pendulum-v1']  # 支持的任务名称\n",
+    "\n",
+    "    need_auto_reset_wrapper = True  # 是否需要 `AutoReset` Wrapper\n",
+    "    need_time_limit_wrapper = True  # 是否需要 `TimeLimit` Wrapper\n",
+    "\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        env_id: str,\n",
+    "        num_envs: int = 1,\n",
+    "        device: torch.device = DEVICE_CPU,\n",
+    "        **kwargs: Any,\n",
+    "    ) -> None:\n",
+    "        super().__init__(env_id)\n",
+    "        self._num_envs = num_envs\n",
+    "        self._env = gymnasium.make(id=env_id, autoreset=True, **kwargs)  # 实例化环境对象\n",
+    "        self._action_space = self._env.action_space  # 指定动作空间,以供算法层初始化读取\n",
+    "        self._observation_space = self._env.observation_space  # 指定观测空间,以供算法层初始化读取\n",
+    "        self._device = device  # 可选项,使用GPU加速。默认为CPU\n",
+    "\n",
+    "    def reset(\n",
+    "        self,\n",
+    "        seed: int | None = None,\n",
+    "        options: dict[str, Any] | None = None,\n",
+    "    ) -> tuple[torch.Tensor, dict[str, Any]]:\n",
+    "        obs, info = self._env.reset(seed=seed, options=options)  # 重置环境\n",
+    "        return (\n",
+    "            torch.as_tensor(obs, dtype=torch.float32, device=self._device),\n",
+    "            info,\n",
+    "        )  # 将重置后的观测转换为torch tensor。\n",
+    "\n",
+    "    @property\n",
+    "    def max_episode_steps(self) -> int | None:\n",
+    "        return self._env.env.spec.max_episode_steps  # 返回环境每一幕的最大交互步数\n",
+    "\n",
+    "    def set_seed(self, seed: int) -> None:\n",
+    "        self.reset(seed=seed)  # 设定环境的随机种子以实现可复现性\n",
+    "\n",
+    "    def render(self) -> Any:\n",
+    "        return self._env.render()  # 返回环境渲染的图像\n",
+    "\n",
+    "    def close(self) -> None:\n",
+    "        self._env.close()  # 训练结束后,释放环境实例\n",
+    "\n",
+    "    def step(\n",
+    "        self,\n",
+    "        action: torch.Tensor,\n",
+    "    ) -> tuple[\n",
+    "        torch.Tensor,\n",
+    "        torch.Tensor,\n",
+    "        torch.Tensor,\n",
+    "        torch.Tensor,\n",
+    "        torch.Tensor,\n",
+    "        dict[str, Any],\n",
+    "    ]:\n",
+    "        obs, reward, terminated, truncated, info = self._env.step(\n",
+    "            action.detach().cpu().numpy(),\n",
+    "        )  # 读取与环境交互后的动态信息\n",
+    "        cost = np.zeros_like(reward)  # Gymnasium并不显式包含安全约束,此处仅为占位。\n",
+    "        obs, reward, cost, terminated, truncated = (\n",
+    "            torch.as_tensor(x, dtype=torch.float32, device=self._device)\n",
+    "            for x in (obs, reward, cost, terminated, truncated)\n",
+    "        )  # 将动态信息转换为torch tensor。\n",
+    "        if 'final_observation' in info:\n",
+    "            info['final_observation'] = np.array(\n",
+    "                [\n",
+    "                    array if array is not None else np.zeros(obs.shape[-1])\n",
+    "                    for array in info['final_observation']\n",
+    "                ],\n",
+    "            )\n",
+    "            info['final_observation'] = torch.as_tensor(\n",
+    "                info['final_observation'],\n",
+    "                dtype=torch.float32,\n",
+    "                device=self._device,\n",
+    "            )  # 将info中记录的上一幕final observation转换为torch tensor。\n",
+    "\n",
+    "        return obs, reward, cost, terminated, truncated, info"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "有关上述代码的具体含义,我们已提供了详细的注释说明。更详细的解释可参考[Tutorial 3: Environment Customization from Zero](./3.Environment%20Customization%20from%20Zero.ipynb)。我们将要点总结如下:\n",
+    "\n",
+    "- **OmniSafe初始化需要的静态变量**\n",
+    "\n",
+    "| 静态信息 | 必须 | 定义 | 类型 | 例子 |\n",
+    "|:---:|:---:|:---:|:---:|:---:|\n",
+    "| `need_auto_reset_wrapper` | 是 | 是否需要 `AutoReset` Wrapper | `bool`变量 | `True` |\n",
+    "| `need_time_limit_wrapper` | 是 | 是否需要 `TimeLimit` Wrapper | `bool`变量 | `True` |\n",
+    "| `_action_space` | 是 | 动作空间 | `gymnasium.spaces.Box` | `Box(low=-1.0, high=1.0, shape=(2,))` |\n",
+    "| `_observation_space` | 是 | 观测空间 | `gymnasium.spaces.Box` | `Box(low=-1.0, high=1.0, shape=(3,))` |\n",
+    "| `max_episode_steps` | 是 | 环境每一幕的最大交互步数 | 带有`@property`装饰器的、返回值为`int`或`None`的函数 | 参考上方代码块 |\n",
+    "| `_num_envs` | 否 | 并行环境数 | `int`变量 | 5 |\n",
+    "| `_device` | 否 | torch计算设备 | `torch.device`变量 | `DEVICE_CPU` |\n",
+    "\n",
+    "- **OmniSafe需要环境提供的动态变量**\n",
+    "\n",
+    "OmniSafe的智能体主要通过`reset`和`step`函数与环境进行动态交互。您需要确保定制化环境的返回值类型、个数与顺序与上述例子一致。更具体地:\n",
+    "\n",
+    "| 动态信息 | 类型 | 个数 | 顺序 |\n",
+    "|:---:|:---:|:---:|:---:|\n",
+    "| `step` | `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]` | 6 | `obs`, `reward`, `cost`, `terminated`, `truncated`, `info` |\n",
+    "| `reset` | `tuple[torch.Tensor, dict[str, Any]]` | 2 | `obs`, `info` |\n",
+    "\n",
+    "- **注意事项**\n",
+    "\n",
+    "1. 尽管`_num_envs`与`_device`并不是必须指定的,但也请您在`__init__`函数中保留这两个参数的输入接口。\n",
+    "2. `_num_envs`是实例化多个环境并行采样的高级参数,它表示实例化环境的数目。如果您的定制化环境同样支持并行数指定,请通过`_num_envs`指定,而不用再定义一个新的接口。"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "随后,将上述环境通过注册装饰器`@env_register`注册入OmniSafe中,即可完成训练。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ExampleMuJoCoEnv has not been registered yet\n",
+      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-05-55/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-55/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
+       "for Jupyter support\n",
+       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
+       "
\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                         Value                 ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                  │ -1616.242431640625    │\n",
+       "│ Metrics/EpCost                 │ 0.0                   │\n",
+       "│ Metrics/EpLen                  │ 200.0                 │\n",
+       "│ Train/Epoch                    │ 0.0                   │\n",
+       "│ Train/Entropy                  │ 1.4185898303985596    │\n",
+       "│ Train/KL                       │ 0.0007516025798395276 │\n",
+       "│ Train/StopIter                 │ 1.0                   │\n",
+       "│ Train/PolicyRatio/Mean         │ 0.9966228604316711    │\n",
+       "│ Train/PolicyRatio/Min          │ 0.9966228604316711    │\n",
+       "│ Train/PolicyRatio/Max          │ 0.9966228604316711    │\n",
+       "│ Train/PolicyRatio/Std          │ 0.0075334208086133    │\n",
+       "│ Train/LR                       │ 0.0                   │\n",
+       "│ Train/PolicyStd                │ 0.9996514320373535    │\n",
+       "│ TotalEnvSteps                  │ 200.0                 │\n",
+       "│ Loss/Loss_pi                   │ 0.08751548826694489   │\n",
+       "│ Loss/Loss_pi/Delta             │ 0.08751548826694489   │\n",
+       "│ Value/Adv                      │ -0.398242324590683    │\n",
+       "│ Loss/Loss_reward_critic        │ 16605.1796875         │\n",
+       "│ Loss/Loss_reward_critic/Delta  │ 16605.1796875         │\n",
+       "│ Value/reward                   │ 0.0049050007946789265 │\n",
+       "│ Loss/Loss_cost_critic          │ 0.052194785326719284  │\n",
+       "│ Loss/Loss_cost_critic/Delta    │ 0.052194785326719284  │\n",
+       "│ Value/cost                     │ 0.07966174930334091   │\n",
+       "│ Time/Total                     │ 0.2075355052947998    │\n",
+       "│ Time/Rollout                   │ 0.1734788417816162    │\n",
+       "│ Time/Update                    │ 0.033020973205566406  │\n",
+       "│ Time/Epoch                     │ 0.20653653144836426   │\n",
+       "│ Time/FPS                       │ 968.3539428710938     │\n",
+       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Min │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Max │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Std │ 0.0                   │\n",
+       "└────────────────────────────────┴───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1616.242431640625 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4185898303985596 │\n", + "│ Train/KL │ 0.0007516025798395276 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Min │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Max │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Std │ 0.0075334208086133 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996514320373535 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.08751548826694489 │\n", + "│ Loss/Loss_pi/Delta │ 0.08751548826694489 │\n", + "│ Value/Adv │ -0.398242324590683 │\n", + "│ Loss/Loss_reward_critic │ 16605.1796875 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16605.1796875 │\n", + "│ Value/reward │ 0.0049050007946789265 │\n", + "│ Loss/Loss_cost_critic │ 0.052194785326719284 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.052194785326719284 │\n", + "│ Value/cost │ 0.07966174930334091 │\n", + "│ Time/Total │ 0.2075355052947998 │\n", + "│ Time/Rollout │ 0.1734788417816162 │\n", + "│ Time/Update │ 0.033020973205566406 │\n", + "│ Time/Epoch │ 0.20653653144836426 │\n", + "│ Time/FPS │ 968.3539428710938 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(-1616.242431640625, 0.0, 200.0)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister # 避免重复运行单元格时产生\"环境已注册\"报错\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + " pass\n", + "\n", + "\n", + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 200,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 200,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 高级使用\n", + "除了上述使用方式外,来自社区的环境还可以享受OmniSafe的环境特定参数指定以及信息记录的特性。我们将详细展示具体操作方式。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 特定参数指定\n", + "\n", + "以`Pendulum-v1`为例,根据Gymnasium的官方文档,创建该任务时可指定一个特定参数为`g`,即重力加速度。我们首先来看看它的默认取值:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-05-58/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-58/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n"
+      ],
+      "text/plain": [
+       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "10.0"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "@env_register\n",
+    "@env_unregister  # 避免重复运行单元格时产生\"环境已注册\"报错\n",
+    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
+    "    def __getattr__(self, name: str) -> Any:\n",
+    "        \"\"\"Get the attribute of the environment.\"\"\"\n",
+    "        if name.startswith('_'):\n",
+    "            raise AttributeError(f'attempted to get missing private attribute {name}')\n",
+    "        return getattr(self._env, name)\n",
+    "\n",
+    "\n",
+    "custom_cfgs = {\n",
+    "    'train_cfgs': {\n",
+    "        'total_steps': 200,\n",
+    "    },\n",
+    "    'algo_cfgs': {\n",
+    "        'steps_per_epoch': 200,\n",
+    "        'update_iters': 1,\n",
+    "    },\n",
+    "}\n",
+    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
+    "agent.agent._env._env.g"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "我们实现了一个名为`__getattr__`的魔法函数,用于访问并查看当前实例化的环境中的特定参数。在本例中,我们发现重力加速度`g`的默认值是10.0。\n",
+    "\n",
+    "通过查阅Gymnasium的文档,该参数可以在调用`gymnasium.make`函数创建环境的过程中指定。OmniSafe是否支持定制化环境的特定参数传递呢?答案是肯定的,具体操作也非常简单:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-06-01/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-06-01/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "9.8" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'env_cfgs': {'g': 9.8}})\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.agent._env._env.g" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "非常好!重力加速度取值被我们更改为了9.8。我们只需要对`env_cfgs`进行操作,将需要定制参数的键与值指定,即可实现环境的特定参数传递。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 信息记录\n", + "\n", + "`Pendulum-v1`任务有许多特定的动态信息,我们将为您介绍如何通过OmniSafe的`Logger`记录这些信息。具体而言,我们将以每幕角速度`angular_velocity`的最大值以及累计值为例为您讲解。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-06-03/progress.csv\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-06-03/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;33mSave with config in config.json\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO: Start training\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mINFO: Start training\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Metrics                          Value                 ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ Metrics/EpRet                   │ -1607.6717529296875   │\n",
+       "│ Metrics/EpCost                  │ 0.0                   │\n",
+       "│ Metrics/EpLen                   │ 200.0                 │\n",
+       "│ Train/Epoch                     │ 0.0                   │\n",
+       "│ Train/Entropy                   │ 1.418560266494751     │\n",
+       "│ Train/KL                        │ 0.0005777678452432156 │\n",
+       "│ Train/StopIter                  │ 1.0                   │\n",
+       "│ Train/PolicyRatio/Mean          │ 0.9981198310852051    │\n",
+       "│ Train/PolicyRatio/Min           │ 0.9981198310852051    │\n",
+       "│ Train/PolicyRatio/Max           │ 0.9981198310852051    │\n",
+       "│ Train/PolicyRatio/Std           │ 0.005412393249571323  │\n",
+       "│ Train/LR                        │ 0.0                   │\n",
+       "│ Train/PolicyStd                 │ 0.9996219277381897    │\n",
+       "│ TotalEnvSteps                   │ 200.0                 │\n",
+       "│ Loss/Loss_pi                    │ 0.09192709624767303   │\n",
+       "│ Loss/Loss_pi/Delta              │ 0.09192709624767303   │\n",
+       "│ Value/Adv                       │ -0.4177907109260559   │\n",
+       "│ Loss/Loss_reward_critic         │ 16393.2265625         │\n",
+       "│ Loss/Loss_reward_critic/Delta   │ 16393.2265625         │\n",
+       "│ Value/reward                    │ 0.00719139538705349   │\n",
+       "│ Loss/Loss_cost_critic           │ 0.05219484493136406   │\n",
+       "│ Loss/Loss_cost_critic/Delta     │ 0.05219484493136406   │\n",
+       "│ Value/cost                      │ 0.07949987053871155   │\n",
+       "│ Time/Total                      │ 0.2163846492767334    │\n",
+       "│ Time/Rollout                    │ 0.18010711669921875   │\n",
+       "│ Time/Update                     │ 0.03433847427368164   │\n",
+       "│ Time/Epoch                      │ 0.21448636054992676   │\n",
+       "│ Time/FPS                        │ 932.4664306640625     │\n",
+       "│ Env/Max_angular_velocity        │ 2.9994523525238037    │\n",
+       "│ Env/Cumulative_angular_velocity │ 1.0643725395202637    │\n",
+       "│ Metrics/LagrangeMultiplier/Mean │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Min  │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Max  │ 0.0                   │\n",
+       "│ Metrics/LagrangeMultiplier/Std  │ 0.0                   │\n",
+       "└─────────────────────────────────┴───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1607.6717529296875 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.418560266494751 │\n", + "│ Train/KL │ 0.0005777678452432156 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Min │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Max │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Std │ 0.005412393249571323 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996219277381897 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.09192709624767303 │\n", + "│ Loss/Loss_pi/Delta │ 0.09192709624767303 │\n", + "│ Value/Adv │ -0.4177907109260559 │\n", + "│ Loss/Loss_reward_critic │ 16393.2265625 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16393.2265625 │\n", + "│ Value/reward │ 0.00719139538705349 │\n", + "│ Loss/Loss_cost_critic │ 0.05219484493136406 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.05219484493136406 │\n", + "│ Value/cost │ 0.07949987053871155 │\n", + "│ Time/Total │ 0.2163846492767334 │\n", + "│ Time/Rollout │ 0.18010711669921875 │\n", + "│ Time/Update │ 0.03433847427368164 │\n", + "│ Time/Epoch │ 0.21448636054992676 │\n", + "│ Time/FPS │ 932.4664306640625 │\n", + "│ Env/Max_angular_velocity │ 2.9994523525238037 │\n", + "│ Env/Cumulative_angular_velocity │ 1.0643725395202637 │\n", + "│ Metrics/LagrangeMultiplier/Mean │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└─────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(-1607.6717529296875, 0.0, 200.0)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from omnisafe.common.logger import Logger\n", + "\n", + "\n", + "@env_register\n", + "@env_unregister # 避免重复运行单元格时产生\"环境已注册\"报错\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + "\n", + " def __init__(self, env_id, num_envs, device, **kwargs):\n", + " super().__init__(env_id, num_envs, device, **kwargs)\n", + " self.env_spec_log = {\n", + " 'Env/Max_angular_velocity': 0.0,\n", + " 'Env/Cumulative_angular_velocity': 0.0,\n", + " } # 在构造函数中重申并指定\n", + "\n", + " def spec_log(self, logger: Logger) -> None:\n", + " for key, value in self.env_spec_log.items():\n", + " logger.store({key: value})\n", + " self.env_spec_log[key] = 0.0\n", + "\n", + " def step(self, action):\n", + " obs, reward, cost, terminated, truncated, info = super().step(action=action)\n", + " angle = obs[-1].item()\n", + " self.env_spec_log['Env/Max_angular_velocity'] = max(\n", + " self.env_spec_log['Env/Max_angular_velocity'], angle\n", + " )\n", + " self.env_spec_log['Env/Cumulative_angular_velocity'] += angle\n", + " return obs, reward, cost, terminated, truncated, info\n", + "\n", + "\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "太好了!我们成功地在`Logger`中记录了需要的环境特定信息。值得注意的是,在这一过程中我们并没有修改OmniSafe的任何源代码。" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## 总结\n", + "我们在本节使用了Gymnasium的经典环境`Pendulum-v1`,为您介绍了将一个社区已有的环境嵌入OmniSafe中所需的必要接口适配与信息提供。我们希望这个教程对您的定制化环境嵌入过程有帮助。如果您想将自己的环境作为OmniSafe官方支持的环境之一,或者在定制化环境中遇到了困难,欢迎在[Issues](https://github.com/PKU-Alignment/omnisafe/issues),[Pull Requests](https://github.com/PKU-Alignment/omnisafe/pulls)与[Discussions](https://github.com/PKU-Alignment/omnisafe/discussions)模块与我们沟通。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}