From efe7d9cea96e743511e11b2056cbf05b5932364e Mon Sep 17 00:00:00 2001
From: Jiayi Zhou <108712610+Gaiejj@users.noreply.github.com>
Date: Tue, 9 Apr 2024 15:24:37 +0800
Subject: [PATCH] feat: support interface of environment customization (#310)
---
.pre-commit-config.yaml | 10 +-
docs/requirements.txt | 1 +
docs/source/envs/custom.rst | 22 +
docs/source/index.rst | 2 +
docs/source/spelling_wordlist.txt | 1 +
docs/source/start/env.rst | 73 +
examples/train_from_custom_env.py | 90 ++
omnisafe/adapter/offpolicy_adapter.py | 6 +-
omnisafe/adapter/online_adapter.py | 37 +-
omnisafe/adapter/onpolicy_adapter.py | 2 +
omnisafe/algorithms/model_based/base/loop.py | 2 +-
omnisafe/algorithms/model_based/base/pets.py | 4 +-
omnisafe/algorithms/off_policy/ddpg.py | 4 +
.../on_policy/base/policy_gradient.py | 17 +-
omnisafe/configs/on-policy/PPOLag.yaml | 2 +
omnisafe/envs/__init__.py | 1 +
omnisafe/envs/core.py | 36 +-
omnisafe/envs/custom_env.py | 199 +++
omnisafe/envs/mujoco_env.py | 17 +-
omnisafe/envs/safety_gymnasium_env.py | 18 +-
omnisafe/envs/safety_gymnasium_modelbased.py | 17 +-
omnisafe/envs/wrapper.py | 1 +
omnisafe/evaluator.py | 19 +-
omnisafe/models/actor_critic/actor_critic.py | 2 -
.../models/actor_critic/actor_q_critic.py | 2 -
omnisafe/utils/config.py | 4 +-
omnisafe/utils/tools.py | 4 +-
tests/distribution_train.py | 2 +-
.../{Simple-v0.npz => Test-v0.npz} | Bin
tests/simple_env.py | 14 +-
tests/test_env.py | 7 +-
tests/test_policy.py | 34 +-
tests/test_registry.py | 19 +
....Environment Customization from Zero.ipynb | 1439 ++++++++++++++++
...ronment Customization from Community.ipynb | 916 +++++++++++
....Environment Customization from Zero.ipynb | 1440 +++++++++++++++++
...ronment Customization from Community.ipynb | 903 +++++++++++
37 files changed, 5253 insertions(+), 114 deletions(-)
create mode 100644 docs/source/envs/custom.rst
create mode 100644 docs/source/start/env.rst
create mode 100644 examples/train_from_custom_env.py
create mode 100644 omnisafe/envs/custom_env.py
rename tests/saved_source/{Simple-v0.npz => Test-v0.npz} (100%)
create mode 100644 tutorials/English/3.Environment Customization from Zero.ipynb
create mode 100644 tutorials/English/4.Environment Customization from Community.ipynb
create mode 100644 tutorials/zh-cn/3.Environment Customization from Zero.ipynb
create mode 100644 tutorials/zh-cn/4.Environment Customization from Community.ipynb
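At a high level, the interface added by this patch lets users subclass `CMDP`, register the subclass with `env_register`, and train on the new task id through `omnisafe.Agent`, without touching OmniSafe's source. The schematic sketch below is assembled from the tutorial notebooks added further down; `MyCustomEnv` and `MyCustom-v0` are illustrative names and the class body is elided (see the notebooks or `examples/train_from_custom_env.py` for complete definitions):

```python
import omnisafe
from omnisafe.envs.core import CMDP, env_register


@env_register
class MyCustomEnv(CMDP):  # illustrative name
    _support_envs = ['MyCustom-v0']  # task ids served by this class
    # ... define the observation/action spaces, reset(), and step() returning
    # (obs, reward, cost, terminated, truncated, info) as shown in the notebooks ...


agent = omnisafe.Agent('PPOLag', 'MyCustom-v0')
agent.learn()
```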
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c8376e71f..0a80ef22d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,25 +29,25 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.292
+ rev: v0.3.5
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/PyCQA/isort
- rev: 5.12.0
+ rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/psf/black
- rev: 23.9.1
+ rev: 24.3.0
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
- rev: v3.15.0
+ rev: v3.15.2
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
- repo: https://github.com/pycqa/flake8
- rev: 6.1.0
+ rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 55f237969..4ee781c54 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -8,3 +8,4 @@ sphinx-autoapi
sphinx-autobuild
sphinx-autodoc-typehints
furo
+sphinxcontrib-spelling
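For context, `sphinxcontrib-spelling` is the Sphinx extension that consumes the `docs/source/spelling_wordlist.txt` file touched by this patch. A generic sketch of how such an extension is typically enabled in a Sphinx `conf.py` (not part of this diff) is:

```python
# docs/source/conf.py -- generic sketch, not part of this patch
extensions = [
    # ... existing Sphinx extensions ...
    'sphinxcontrib.spelling',
]
# Project-specific words the checker should accept, one per line.
spelling_word_list_filename = 'spelling_wordlist.txt'
spelling_show_suggestions = True
```

The check is then run with the dedicated builder, e.g. `sphinx-build -b spelling docs/source docs/build/spelling`.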
diff --git a/docs/source/envs/custom.rst b/docs/source/envs/custom.rst
new file mode 100644
index 000000000..85e33a1a7
--- /dev/null
+++ b/docs/source/envs/custom.rst
@@ -0,0 +1,22 @@
+OmniSafe Customization Interface of Environments
+================================================
+
+.. currentmodule:: omnisafe.envs.custom_env
+
+.. autosummary::
+
+ CustomEnv
+
+CustomEnv
+---------
+
+.. card::
+ :class-header: sd-bg-success sd-text-white
+ :class-card: sd-outline-success sd-rounded-1
+
+ Documentation
+ ^^^
+
+ .. autoclass:: CustomEnv
+ :members:
+ :private-members:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index e759bebee..792f62052 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -366,6 +366,7 @@ this project, don't hesitate to ask your question on `the GitHub issue page
Logging data to ./runs/PPOLag-{Example-v0}/seed-000-2024-04-09-15-08-37/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-08-37/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", + "\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.0002234023268101737 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.00019999999494757503 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ 7.748603536583687e-08 │\n", + "│ Loss/Loss_pi/Delta │ 7.748603536583687e-08 │\n", + "│ Value/Adv │ -1.7881394143159923e-08 │\n", + "│ Loss/Loss_reward_critic │ 10.457597732543945 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.457597732543945 │\n", + "│ Value/reward │ -0.012156231328845024 │\n", + "│ Loss/Loss_cost_critic │ 18.316673278808594 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.316673278808594 │\n", + "│ Value/cost │ 0.1599183827638626 │\n", + "│ Time/Total │ 0.03895211219787598 │\n", + "│ Time/Rollout │ 0.021677017211914062 │\n", + "│ Time/Update │ 0.01619410514831543 │\n", + "│ Time/Epoch │ 0.0379033088684082 │\n", + "│ Time/FPS │ 263.8358459472656 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.0002234023268101737 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.00019999999494757503 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ 7.748603536583687e-08 │\n", + "│ Loss/Loss_pi/Delta │ 7.748603536583687e-08 │\n", + "│ Value/Adv │ -1.7881394143159923e-08 │\n", + "│ Loss/Loss_reward_critic │ 10.457597732543945 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.457597732543945 │\n", + "│ Value/reward │ -0.012156231328845024 │\n", + "│ Loss/Loss_cost_critic │ 18.316673278808594 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.316673278808594 │\n", + "│ Value/cost │ 0.1599183827638626 │\n", + "│ Time/Total │ 0.03895211219787598 │\n", + "│ Time/Rollout │ 0.021677017211914062 │\n", + "│ Time/Update │ 0.01619410514831543 │\n", + "│ Time/Epoch │ 0.0379033088684082 │\n", + "│ Time/FPS │ 263.8358459472656 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 10.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m10.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 7.8531365394592285 │\n", + "│ Metrics/EpCost │ 7.931504726409912 │\n", + "│ Metrics/EpLen │ 6.666666507720947 │\n", + "│ Train/Epoch │ 1.0 │\n", + "│ Train/Entropy │ 1.4192386865615845 │\n", + "│ Train/KL │ 8.405959670199081e-05 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 9.999999747378752e-05 │\n", + "│ Train/PolicyStd │ 1.0003000497817993 │\n", + "│ TotalEnvSteps │ 20.0 │\n", + "│ Loss/Loss_pi │ -8.940696716308594e-08 │\n", + "│ Loss/Loss_pi/Delta │ -1.668930025289228e-07 │\n", + "│ Value/Adv │ 8.940696716308594e-08 │\n", + "│ Loss/Loss_reward_critic │ 37.962928771972656 │\n", + "│ Loss/Loss_reward_critic/Delta │ 27.50533103942871 │\n", + "│ Value/reward │ -0.00784378219395876 │\n", + "│ Loss/Loss_cost_critic │ 25.662063598632812 │\n", + "│ Loss/Loss_cost_critic/Delta │ 7.345390319824219 │\n", + "│ Value/cost │ 0.11082335561513901 │\n", + "│ Time/Total │ 0.08216094970703125 │\n", + "│ Time/Rollout │ 0.01664590835571289 │\n", + "│ Time/Update │ 0.013554811477661133 │\n", + "│ Time/Epoch │ 0.03022909164428711 │\n", + "│ Time/FPS │ 330.8123779296875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 7.8531365394592285 │\n", + "│ Metrics/EpCost │ 7.931504726409912 │\n", + "│ Metrics/EpLen │ 6.666666507720947 │\n", + "│ Train/Epoch │ 1.0 │\n", + "│ Train/Entropy │ 1.4192386865615845 │\n", + "│ Train/KL │ 8.405959670199081e-05 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 9.999999747378752e-05 │\n", + "│ Train/PolicyStd │ 1.0003000497817993 │\n", + "│ TotalEnvSteps │ 20.0 │\n", + "│ Loss/Loss_pi │ -8.940696716308594e-08 │\n", + "│ Loss/Loss_pi/Delta │ -1.668930025289228e-07 │\n", + "│ Value/Adv │ 8.940696716308594e-08 │\n", + "│ Loss/Loss_reward_critic │ 37.962928771972656 │\n", + "│ Loss/Loss_reward_critic/Delta │ 27.50533103942871 │\n", + "│ Value/reward │ -0.00784378219395876 │\n", + "│ Loss/Loss_cost_critic │ 25.662063598632812 │\n", + "│ Loss/Loss_cost_critic/Delta │ 7.345390319824219 │\n", + "│ Value/cost │ 0.11082335561513901 │\n", + "│ Time/Total │ 0.08216094970703125 │\n", + "│ Time/Rollout │ 0.01664590835571289 │\n", + "│ Time/Update │ 0.013554811477661133 │\n", + "│ Time/Epoch │ 0.03022909164428711 │\n", + "│ Time/FPS │ 330.8123779296875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + 
}, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 9.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m9.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 6.297085762023926 │\n", + "│ Metrics/EpCost │ 6.2187700271606445 │\n", + "│ Metrics/EpLen │ 5.25 │\n", + "│ Train/Epoch │ 2.0 │\n", + "│ Train/Entropy │ 1.419387936592102 │\n", + "│ Train/KL │ 6.185231995914364e-06 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Min │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Max │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0004496574401855 │\n", + "│ TotalEnvSteps │ 30.0 │\n", + "│ Loss/Loss_pi │ 7.152557657263969e-08 │\n", + "│ Loss/Loss_pi/Delta │ 1.6093254373572563e-07 │\n", + "│ Value/Adv │ -1.4305115314527939e-07 │\n", + "│ Loss/Loss_reward_critic │ 34.879573822021484 │\n", + "│ Loss/Loss_reward_critic/Delta │ -3.083354949951172 │\n", + "│ Value/reward │ 0.020589731633663177 │\n", + "│ Loss/Loss_cost_critic │ 27.62775230407715 │\n", + "│ Loss/Loss_cost_critic/Delta │ 1.965688705444336 │\n", + "│ Value/cost │ 0.13300421833992004 │\n", + "│ Time/Total │ 0.12445831298828125 │\n", + "│ Time/Rollout │ 0.0154266357421875 │\n", + "│ Time/Update │ 0.009746313095092773 │\n", + "│ Time/Epoch │ 0.02520155906677246 │\n", + "│ Time/FPS │ 396.81585693359375 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 6.297085762023926 │\n", + "│ Metrics/EpCost │ 6.2187700271606445 │\n", + "│ Metrics/EpLen │ 5.25 │\n", + "│ Train/Epoch │ 2.0 │\n", + "│ Train/Entropy │ 1.419387936592102 │\n", + "│ Train/KL │ 6.185231995914364e-06 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Min │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Max │ 0.9999998211860657 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0004496574401855 │\n", + "│ TotalEnvSteps │ 30.0 │\n", + "│ Loss/Loss_pi │ 7.152557657263969e-08 │\n", + "│ Loss/Loss_pi/Delta │ 1.6093254373572563e-07 │\n", + "│ Value/Adv │ -1.4305115314527939e-07 │\n", + "│ Loss/Loss_reward_critic │ 34.879573822021484 │\n", + "│ Loss/Loss_reward_critic/Delta │ -3.083354949951172 │\n", + "│ Value/reward │ 0.020589731633663177 │\n", + "│ Loss/Loss_cost_critic │ 27.62775230407715 │\n", + "│ Loss/Loss_cost_critic/Delta │ 1.965688705444336 │\n", + "│ Value/cost │ 0.13300421833992004 │\n", + "│ Time/Total │ 0.12445831298828125 │\n", + "│ Time/Rollout │ 0.0154266357421875 │\n", + "│ Time/Update │ 0.009746313095092773 │\n", + "│ Time/Epoch │ 0.02520155906677246 │\n", + "│ Time/FPS │ 396.81585693359375 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, 
+ "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(6.297085762023926, 6.2187700271606445, 5.25)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 30,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 10,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Example-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Well done! We have completed the embedding and training of this customized environment. Next, we will further explore how to specify hyperparameters for the environment.\n", + "\n", + "### Parameter Setting\n", + "\n", + "Starting with a new example environment, assume this environment requires a parameter named `num_agents`. We will show how to complete the parameter setting without modifying OmniSafe's code." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NewExampleEnv has not been registered yet\n" + ] + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class NewExampleEnv(ExampleEnv): # make a new environment\n", + " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n", + " num_agents: ClassVar[int] = 1\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n", + " self.num_agents = kwargs.get('num_agents', 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the `num_agents` parameter is set to a default value: `1`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_env = NewExampleEnv('NewExample-v0')\n", + "new_env.num_agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below we will show how to modify this parameter through OmniSafe's interface and train:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-08-46/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-08-46/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "2"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "custom_cfgs.update({'env_cfgs': {'num_agents': 2}})\n",
+ "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n",
+ "agent.agent._env._env.num_agents"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Excellent! We have set `num_agents` to 2. This means we have successfully implemented hyperparameter setting without modifying the code.\n",
+ "\n",
+ "### Training Information Recording\n",
+ "\n",
+ "While running the training code, you may have noticed that OmniSafe records training information through `Logger`, for example:\n",
+ "\n",
+ "```bash\n",
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Metrics ┃ Value ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ Metrics/EpRet │ 2.046875 │\n",
+ "│ Metrics/EpCost │ 2.89453125 │\n",
+ "│ Metrics/EpLen │ 3.25 │\n",
+ "│ Train/Epoch │ 3.0 │\n",
+ "...\n",
+ "```\n",
+ "So, can we output information from the environment into the log? The answer is yes, and this process also does not require modifying OmniSafe's code. You only need to implement two standard interfaces:\n",
+ "1. In the `__init__` function, add the information you want to output to `self.env_spec_log`.\n",
+ "2. Instantiate the `spec_log` function to record the required information.\n",
+ "\n",
+ "**Please note:** Currently, OmniSafe only supports recording this information at the end of each epoch, not after each step."
+ ]
+ },
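As a compact recap of the two hooks just described, and mirroring the notebook cell that follows (assuming the `ExampleEnv` base class defined earlier in the notebook):

```python
class MyLoggedEnv(ExampleEnv):  # illustrative name
    def __init__(self, env_id: str, **kwargs) -> None:
        super().__init__(env_id, **kwargs)
        # 1. declare every environment-specific key you want to log
        self.env_spec_log = {'Env/Success_counts': 0}

    def spec_log(self, logger) -> None:
        # 2. push the accumulated value to the logger (called once per epoch), then reset it
        logger.store({'Env/Success_counts': self.env_spec_log['Env/Success_counts']})
        self.env_spec_log['Env/Success_counts'] = 0
```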
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@env_register\n",
+ "@env_unregister\n",
+ "class NewExampleEnv(ExampleEnv):\n",
+ " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n",
+ "\n",
+ " # define what to log\n",
+ " def __init__(self, env_id: str, **kwargs) -> None:\n",
+ " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n",
+ " self.env_spec_log = {'Env/Success_counts': 0}\n",
+ "\n",
+ " # interact with the environment and log\n",
+ " def step(self, action):\n",
+ " obs, reward, cost, terminated, truncated, info = super().step(action)\n",
+ " success = int(reward > cost)\n",
+ " self.env_spec_log['Env/Success_counts'] += success\n",
+ " return obs, reward, cost, terminated, truncated, info\n",
+ "\n",
+ " # write to logger\n",
+ " def spec_log(self, logger) -> dict[str, Any]:\n",
+ " logger.store({'Env/Success_counts': self.env_spec_log['Env/Success_counts']})\n",
+ " self.env_spec_log['Env/Success_counts'] = 0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will briefly train and observe whether this information has been successfully recorded."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-08-52/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-08-52/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00026566203450784087 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -2.9802322387695312e-08 │\n", + "│ Loss/Loss_pi/Delta │ -2.9802322387695312e-08 │\n", + "│ Value/Adv │ 5.9604645663569045e-09 │\n", + "│ Loss/Loss_reward_critic │ 10.46424674987793 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.46424674987793 │\n", + "│ Value/reward │ -0.017885426059365273 │\n", + "│ Loss/Loss_cost_critic │ 18.490144729614258 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.490144729614258 │\n", + "│ Value/cost │ 0.13730722665786743 │\n", + "│ Time/Total │ 0.0326535701751709 │\n", + "│ Time/Rollout │ 0.019308805465698242 │\n", + "│ Time/Update │ 0.012392044067382812 │\n", + "│ Time/Epoch │ 0.03173708915710449 │\n", + "│ Time/FPS │ 315.0982360839844 │\n", + "│ Env/Success_counts │ 1.5 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00026566203450784087 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -2.9802322387695312e-08 │\n", + "│ Loss/Loss_pi/Delta │ -2.9802322387695312e-08 │\n", + "│ Value/Adv │ 5.9604645663569045e-09 │\n", + "│ Loss/Loss_reward_critic │ 10.46424674987793 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.46424674987793 │\n", + "│ Value/reward │ -0.017885426059365273 │\n", + "│ Loss/Loss_cost_critic │ 18.490144729614258 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.490144729614258 │\n", + "│ Value/cost │ 0.13730722665786743 │\n", + "│ Time/Total │ 0.0326535701751709 │\n", + "│ Time/Rollout │ 0.019308805465698242 │\n", + "│ Time/Update │ 0.012392044067382812 │\n", + "│ Time/Epoch │ 0.03173708915710449 │\n", + "│ Time/FPS │ 315.0982360839844 │\n", + "│ Env/Success_counts │ 1.5 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { 
+ "data": { + "text/plain": [ + "(5.625942230224609, 6.960921287536621, 5.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'train_cfgs': {'total_steps': 10}})\n", + "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nice! The above code has outputted the environment-specific information `Env/Success_counts` to the terminal. This process does not require any modifications to the original code.\n", + "\n", + "## Summary\n", + "OmniSafe aims to become the foundational software for safe reinforcement learning. We will continue to refine the environmental interface standards of OmniSafe, enabling it to adapt to various safe reinforcement learning tasks and empower diverse safety scenarios." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/English/4.Environment Customization from Community.ipynb b/tutorials/English/4.Environment Customization from Community.ipynb new file mode 100644 index 000000000..d5123f7c9 --- /dev/null +++ b/tutorials/English/4.Environment Customization from Community.ipynb @@ -0,0 +1,916 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization from Community\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Gymnasium: https://github.com/Farama-Foundation/Gymnasium\n", + "\n", + "[Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API.\n", + "\n", + "## Introduction\n", + "\n", + "In this section, we will introduce how to embed an existing environment from the community into OmniSafe. The series of tasks provided by [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) have been widely applied in reinforcement learning. Specifically, this section will use [Pendulum-v1](https://gymnasium.farama.org/environments/classic_control/pendulum/) as an example to show how to embed Gymnasium's tasks into OmniSafe.\n", + "\n", + "## Quick Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install via pip (ignore it if you have already installed).\n", + "%pip install omnisafe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install from source (ignore it if you have already installed).\n", + "## clone the repo\n", + "%git clone https://github.com/PKU-Alignment/omnisafe\n", + "%cd omnisafe\n", + "\n", + "## install it\n", + "%pip install -e ." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gymnasium Task Embedding\n", + "The core part for environment embedding is to provide sufficient static or dynamic information for SafeRL agent interaction and training. This section will detail the variables that must be defined for embedding environments and the corresponding standards. We will first present the entire embedding process in the order of code organization, giving a preliminary understanding. Then, we will review all the codes, summarize, and organize the adaptations you need to make when customizing your environment.\n", + "\n", + "### Quick Start\n", + "First, import all external variables required for this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# import all we need\n", + "from __future__ import annotations\n", + "\n", + "from typing import Any, ClassVar\n", + "import gymnasium\n", + "import torch\n", + "import numpy as np\n", + "import omnisafe\n", + "\n", + "from omnisafe.envs.core import CMDP, env_register, env_unregister\n", + "from omnisafe.typing import DEVICE_CPU" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, create a class named `ExampleMuJoCoEnv`, which needs to inherit from `CMDP`. (This is because we want to transform the environment's interaction form into the CMDP paradigm. You can define new abstract classes as needed to implement new paradigms)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleMuJoCoEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Pendulum-v1'] # Supported task names\n", + "\n", + " need_auto_reset_wrapper = True # Whether `AutoReset` Wrapper is needed\n", + " need_time_limit_wrapper = True # Whether `TimeLimit` Wrapper is needed\n", + "\n", + " def __init__(\n", + " self,\n", + " env_id: str,\n", + " num_envs: int = 1,\n", + " device: torch.device = DEVICE_CPU,\n", + " **kwargs: Any,\n", + " ) -> None:\n", + " super().__init__(env_id)\n", + " self._num_envs = num_envs\n", + " # Instantiate the environment object\n", + " self._env = gymnasium.make(id=env_id, autoreset=True, **kwargs)\n", + " # Specify the action space for initialization by the algorithm layer\n", + " self._action_space = self._env.action_space\n", + " # Specify the observation space for initialization by the algorithm layer\n", + " self._observation_space = self._env.observation_space\n", + " # Optional, for GPU acceleration. 
Default is CPU\n", + " self._device = device # 可选项,使用GPU加速。默认为CPU\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict[str, Any]]:\n", + " # Reset the environment\n", + " obs, info = self._env.reset(seed=seed, options=options)\n", + " # Convert the reset observations to a torch tensor.\n", + " return (\n", + " torch.as_tensor(obs, dtype=torch.float32, device=self._device),\n", + " info,\n", + " )\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> int | None:\n", + " # Return the maximum number of interaction steps per episode in the environment\n", + " return self._env.env.spec.max_episode_steps\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " # Set the environment's random seed for reproducibility\n", + " self.reset(seed=seed) # 设定环境的随机种子以实现可复现性\n", + "\n", + " def render(self) -> Any:\n", + " # Return the image rendered by the environment\n", + " return self._env.render()\n", + "\n", + " def close(self) -> None:\n", + " # Release the environment instance after training ends\n", + " self._env.close()\n", + "\n", + " def step(\n", + " self,\n", + " action: torch.Tensor,\n", + " ) -> tuple[\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " dict[str, Any],\n", + " ]:\n", + " # Read the dynamic information after interacting with the environment\n", + " obs, reward, terminated, truncated, info = self._env.step(\n", + " action.detach().cpu().numpy(),\n", + " )\n", + " # Gymnasium does not explicitly include safety constraints; this is just a placeholder.\n", + " cost = np.zeros_like(reward)\n", + " # Convert dynamic information into torch tensor.\n", + " obs, reward, cost, terminated, truncated = (\n", + " torch.as_tensor(x, dtype=torch.float32, device=self._device)\n", + " for x in (obs, reward, cost, terminated, truncated)\n", + " )\n", + " if 'final_observation' in info:\n", + " info['final_observation'] = np.array(\n", + " [\n", + " array if array is not None else np.zeros(obs.shape[-1])\n", + " for array in info['final_observation']\n", + " ],\n", + " )\n", + " # Convert the last observation recorded in info into a torch tensor.\n", + " info['final_observation'] = torch.as_tensor(\n", + " info['final_observation'],\n", + " dtype=torch.float32,\n", + " device=self._device,\n", + " )\n", + "\n", + " return obs, reward, cost, terminated, truncated, info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Regarding the specific meaning of the above code, we have provided detailed annotation explanations. For more detailed explanations, please refer to [Tutorial 3: Environment Customization from Zero](./3.Environment%20Customization.ipynb). 
We summarize the key points as follows:\n", + "\n", + "- **Static variables needed for OmniSafe initialization**\n", + "\n", + "| Static Information | Required | Definition | Type | Example |\n", + "|:---:|:---:|:---:|:---:|:---:|\n", + "| `need_auto_reset_wrapper` | Yes | Whether an `AutoReset` Wrapper is needed | `bool` variable | `True` |\n", + "| `need_time_limit_wrapper` | Yes | Whether a `TimeLimit` Wrapper is needed | `bool` variable | `True` |\n", + "| `_action_space` | Yes | Action space | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(2,)` |\n", + "| `_observation_space` | Yes | Observation space | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(3,)` |\n", + "| `max_episode_steps` | Yes | The maximum number of interaction steps per episode in the environment | Function with `@property` decorator, returning a variable of type `int` or `None` | Refer to the code block above |\n", + "| `_num_envs` | No | Number of parallel environments | `int` variable | 5 |\n", + "| `_device` | No | Torch computing device | `torch.device` variable | `DEVICE_CPU` |\n", + "\n", + "- **Dynamic variables required by the environment for OmniSafe**\n", + "\n", + "OmniSafe's agents mainly interact dynamically with the environment through the `reset` and `step` functions. You need to ensure that the return type, number, and order of your customized environment match the examples above, more specifically:\n", + "\n", + "| Dynamic Information | Type | Number | Order |\n", + "|:---:|:---:|:---:|:---:|\n", + "| `step` | `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]` | 6 | `obs`, `reward`, `cost`, `terminated`, `truncated`, `info` |\n", + "| `reset` | `tuple[torch.Tensor, dict[str, Any]]` | 2 | `obs`, `info` |\n", + "\n", + "- **Precautions**\n", + "\n", + "1. Although `_num_envs` and `_device` are not mandatory, please retain the input interface for these two parameters in the `__init__` function.\n", + "2. `_num_envs` is an advanced parameter for instantiating multiple environments for parallel sampling, representing the number of environments instantiated. If your customized environment also supports specifying the parallel number, please specify it through `_num_envs` instead of defining a new interface." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Subsequently, by registering the above environment into OmniSafe with the registration decorator `@env_register`, you can complete the training." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ExampleMuJoCoEnv has not been registered yet\n", + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-14/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-14/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", + "\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1616.242431640625 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4185898303985596 │\n", + "│ Train/KL │ 0.0007516025798395276 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Min │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Max │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Std │ 0.0075334208086133 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996514320373535 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.08751548826694489 │\n", + "│ Loss/Loss_pi/Delta │ 0.08751548826694489 │\n", + "│ Value/Adv │ -0.398242324590683 │\n", + "│ Loss/Loss_reward_critic │ 16605.1796875 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16605.1796875 │\n", + "│ Value/reward │ 0.0049050007946789265 │\n", + "│ Loss/Loss_cost_critic │ 0.052194785326719284 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.052194785326719284 │\n", + "│ Value/cost │ 0.07966174930334091 │\n", + "│ Time/Total │ 0.21084904670715332 │\n", + "│ Time/Rollout │ 0.17566156387329102 │\n", + "│ Time/Update │ 0.03439140319824219 │\n", + "│ Time/Epoch │ 0.21008920669555664 │\n", + "│ Time/FPS │ 951.9786987304688 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴───────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1616.242431640625 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4185898303985596 │\n", + "│ Train/KL │ 0.0007516025798395276 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Min │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Max │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Std │ 0.0075334208086133 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996514320373535 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.08751548826694489 │\n", + "│ Loss/Loss_pi/Delta │ 0.08751548826694489 │\n", + "│ Value/Adv │ -0.398242324590683 │\n", + "│ Loss/Loss_reward_critic │ 16605.1796875 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16605.1796875 │\n", + "│ Value/reward │ 0.0049050007946789265 │\n", + "│ Loss/Loss_cost_critic │ 0.052194785326719284 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.052194785326719284 │\n", + "│ Value/cost │ 0.07966174930334091 │\n", + "│ Time/Total │ 0.21084904670715332 │\n", + "│ Time/Rollout │ 0.17566156387329102 │\n", + "│ Time/Update │ 0.03439140319824219 │\n", + "│ Time/Epoch │ 0.21008920669555664 │\n", + "│ Time/FPS │ 951.9786987304688 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" 
+ }, + { + "data": { + "text/plain": [ + "(-1616.242431640625, 0.0, 200.0)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister # Avoid the \"environment has been registered\" error when rerunning cells\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + " pass\n", + "\n", + "\n", + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 200,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 200,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Advanced Usage\n", + "In addition to the aforementioned methods, environments from the community can also take advantage of OmniSafe's capabilities for specifying environment-specific parameters and recording information. We will detail the specific operational methods." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Specifying Specific Parameters\n", + "\n", + "Taking `Pendulum-v1` as an example, according to the Gymnasium documentation, a specific parameter `g`, which stands for gravitational acceleration, can be specified when creating this task. Let's first take a look at its default value:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-17/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-17/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "10.0"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "@env_register\n",
+ "@env_unregister # Avoid the \"environment has been registered\" error when rerunning cells\n",
+ "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
+ " def __getattr__(self, name: str) -> Any:\n",
+ " \"\"\"Get the attribute of the environment.\"\"\"\n",
+ " if name.startswith('_'):\n",
+ " raise AttributeError(f'attempted to get missing private attribute {name}')\n",
+ " return getattr(self._env, name)\n",
+ "\n",
+ "\n",
+ "custom_cfgs = {\n",
+ " 'train_cfgs': {\n",
+ " 'total_steps': 200,\n",
+ " },\n",
+ " 'algo_cfgs': {\n",
+ " 'steps_per_epoch': 200,\n",
+ " 'update_iters': 1,\n",
+ " },\n",
+ "}\n",
+ "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
+ "agent.agent._env._env.g"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We implemented a magic function named `__get_attr__` to call and view specific parameters in the currently instantiated environment. In this case, we find that the default value of the gravitational acceleration `g` is 10.0.\n",
+ "\n",
+ "By consulting the Gymnasium documentation, this parameter can be specified during the process of creating an environment with the `gymnasium.make` function. Does OmniSafe support the passing of specific parameters for customized environments? The answer is yes:"
+ ]
+ },
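For reference, the same default can be inspected in plain Gymnasium, which is what `env_cfgs` ultimately forwards keyword arguments to through `gymnasium.make` in the wrapper class above (a small sketch independent of OmniSafe):

```python
import gymnasium

env = gymnasium.make('Pendulum-v1')
print(env.unwrapped.g)  # 10.0 -- the documented default

env = gymnasium.make('Pendulum-v1', g=9.8)
print(env.unwrapped.g)  # 9.8 -- overridden at construction time
```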
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-20/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-20/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "9.8"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "custom_cfgs.update({'env_cfgs': {'g': 9.8}})\n",
+ "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
+ "agent.agent._env._env.g"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Nice! The value of gravitational acceleration has been changed to 9.8. We just need to operate on `env_cfgs`, specifying the key and value of the parameter to be customized, to achieve the passing of specific parameters for the environment."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Information Recording\n",
+ "\n",
+ "The `Pendulum-v1` task contains many specific dynamic pieces of information. We will introduce how to record these pieces of information using OmniSafe's `Logger`. Specifically, we will explain using the maximum and cumulative values of the angular velocity `angular_velocity` per episode as examples."
+ ]
+ },
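The code cell that follows is hard to read inside this diff, so here is a condensed sketch of the pattern it implements; the class name `LoggedPendulumEnv` is illustrative, it assumes the `ExampleMuJoCoEnv` class defined above, and it relies on the last entry of the `Pendulum-v1` observation being the angular velocity:

```python
from omnisafe.common.logger import Logger


class LoggedPendulumEnv(ExampleMuJoCoEnv):  # illustrative name
    def __init__(self, env_id, num_envs, device, **kwargs):
        super().__init__(env_id, num_envs, device, **kwargs)
        # declare the environment-specific metrics to be logged each epoch
        self.env_spec_log = {
            'Env/Max_angular_velocity': 0.0,
            'Env/Cumulative_angular_velocity': 0.0,
        }

    def step(self, action):
        obs, reward, cost, terminated, truncated, info = super().step(action)
        angular_velocity = obs[-1].item()  # last observation entry of Pendulum-v1
        self.env_spec_log['Env/Max_angular_velocity'] = max(
            self.env_spec_log['Env/Max_angular_velocity'], angular_velocity
        )
        self.env_spec_log['Env/Cumulative_angular_velocity'] += angular_velocity
        return obs, reward, cost, terminated, truncated, info

    def spec_log(self, logger: Logger) -> None:
        # flush the accumulated values to the logger once per epoch, then reset them
        for key, value in self.env_spec_log.items():
            logger.store({key: value})
            self.env_spec_log[key] = 0.0
```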
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-09-23/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-23/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1607.6717529296875 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.418560266494751 │\n", + "│ Train/KL │ 0.0005777678452432156 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Min │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Max │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Std │ 0.005412393249571323 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996219277381897 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.09192709624767303 │\n", + "│ Loss/Loss_pi/Delta │ 0.09192709624767303 │\n", + "│ Value/Adv │ -0.4177907109260559 │\n", + "│ Loss/Loss_reward_critic │ 16393.2265625 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16393.2265625 │\n", + "│ Value/reward │ 0.00719139538705349 │\n", + "│ Loss/Loss_cost_critic │ 0.05219484493136406 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.05219484493136406 │\n", + "│ Value/cost │ 0.07949987053871155 │\n", + "│ Time/Total │ 0.20513606071472168 │\n", + "│ Time/Rollout │ 0.17486166954040527 │\n", + "│ Time/Update │ 0.029330968856811523 │\n", + "│ Time/Epoch │ 0.20422101020812988 │\n", + "│ Time/FPS │ 979.3323364257812 │\n", + "│ Env/Max_angular_velocity │ 2.9994523525238037 │\n", + "│ Env/Cumulative_angular_velocity │ 1.0643725395202637 │\n", + "│ Metrics/LagrangeMultiplier/Mean │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└─────────────────────────────────┴───────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1607.6717529296875 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.418560266494751 │\n", + "│ Train/KL │ 0.0005777678452432156 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Min │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Max │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Std │ 0.005412393249571323 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996219277381897 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.09192709624767303 │\n", + "│ Loss/Loss_pi/Delta │ 0.09192709624767303 │\n", + "│ Value/Adv │ -0.4177907109260559 │\n", + "│ Loss/Loss_reward_critic │ 16393.2265625 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16393.2265625 │\n", + "│ Value/reward │ 0.00719139538705349 │\n", + "│ Loss/Loss_cost_critic │ 0.05219484493136406 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.05219484493136406 │\n", + "│ Value/cost │ 0.07949987053871155 │\n", + "│ Time/Total │ 0.20513606071472168 │\n", + "│ Time/Rollout │ 0.17486166954040527 │\n", + "│ Time/Update │ 0.029330968856811523 │\n", + "│ Time/Epoch │ 0.20422101020812988 │\n", + "│ Time/FPS │ 979.3323364257812 │\n", + "│ Env/Max_angular_velocity │ 2.9994523525238037 │\n", + "│ Env/Cumulative_angular_velocity │ 1.0643725395202637 │\n", + "│ Metrics/LagrangeMultiplier/Mean │ 0.0 │\n", + "│ 
Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└─────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(-1607.6717529296875, 0.0, 200.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from omnisafe.common.logger import Logger\n", + "\n", + "\n", + "@env_register\n", + "@env_unregister # Avoid the \"environment has been registered\" error when rerunning cells\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + "\n", + " def __init__(self, env_id, num_envs, device, **kwargs):\n", + " super().__init__(env_id, num_envs, device, **kwargs)\n", + " self.env_spec_log = {\n", + " 'Env/Max_angular_velocity': 0.0,\n", + " 'Env/Cumulative_angular_velocity': 0.0,\n", + " } # Reiterate and specify in the constructor\n", + "\n", + " def spec_log(self, logger: Logger) -> None:\n", + " for key, value in self.env_spec_log.items():\n", + " logger.store({key: value})\n", + " self.env_spec_log[key] = 0.0\n", + "\n", + " def step(self, action):\n", + " obs, reward, cost, terminated, truncated, info = super().step(action=action)\n", + " angle = obs[-1].item()\n", + " self.env_spec_log['Env/Max_angular_velocity'] = max(\n", + " self.env_spec_log['Env/Max_angular_velocity'], angle\n", + " )\n", + " self.env_spec_log['Env/Cumulative_angular_velocity'] += angle\n", + " return obs, reward, cost, terminated, truncated, info\n", + "\n", + "\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! We successfully recorded the required environment-specific information in the `Logger`. It is worth noting that, in this process, we did not modify any of OmniSafe's source code." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "In this section, using Gymnasium's classic environment `Pendulum-v1`, we introduced the necessary interface adaptation and information provision required to embed an existing community environment into OmniSafe. We hope this tutorial is helpful for embedding your customized environment. If you wish to have your environment supported as one of the official OmniSafe environments, or if you encounter difficulties in customizing environments, you are welcome to communicate with us through the [Issues](https://github.com/PKU-Alignment/omnisafe/issues), [Pull Requests](https://github.com/PKU-Alignment/omnisafe/pulls), and [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions) modules." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/zh-cn/3.Environment Customization from Zero.ipynb b/tutorials/zh-cn/3.Environment Customization from Zero.ipynb new file mode 100644 index 000000000..14f45e272 --- /dev/null +++ b/tutorials/zh-cn/3.Environment Customization from Zero.ipynb @@ -0,0 +1,1440 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization From Zero\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Safety-Gymnasium: https://www.safety-gymnasium.com/\n", + "\n", + "[Safety-Gymnasium](https://www.safety-gymnasium.com/) is a highly scalable and customizable Safe Reinforcement Learning library, aiming to deliver a good view of benchmarking Safe Reinforcement Learning (Safe RL) algorithms and a more standardized setting of environments. \n", + "\n", + "## 引言\n", + "\n", + "本节与[Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb)共同介绍了如何令定制化环境享受OmniSafe提供的全套训练、记录与保存框架。本节侧重于面向安全强化学习初学者介绍如何从零开始创建环境;而[Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb)关注如何将社区已有的环境,例如[Gymnasium](https://github.com/Farama-Foundation/Gymnasium),作出最小适配,以嵌入OmniSafe中。\n", + "\n", + "具体而言,本节提供了一个用于定制化环境的最简单模版。通过该模版,您将了解:\n", + "\n", + "- 如何在OmniSafe中创建并注册一个环境。\n", + "- 如何指定创建环境时的定制化参数。\n", + "- 如何记录环境特定的信息。\n", + "\n", + "## 快速安装" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 通过pip安装(如果您已经安装,请忽略此段代码)\n", + "%pip install omnisafe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 通过源代码安装(如果您已经安装,请忽略此段代码)\n", + "## 克隆仓库\n", + "%git clone https://github.com/PKU-Alignment/omnisafe\n", + "%cd omnisafe\n", + "\n", + "## 完成安装\n", + "%pip install -e ." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 定制化环境最简模版\n", + "OmniSafe的定制化环境可以仅通过单个文件实现。我们将为您介绍一个最简的定制化环境模版,它将作为您入门的起点。\n", + "\n", + "### 定制化环境设计\n", + "我们将在此细致地介绍一个简易随机环境的设计过程。如果您是强化学习领域的专家或有经验的研究者,可以跳过该模块至[定制化环境嵌入](#定制化环境嵌入)或[Tutorial 4: Environment Customization from Community](./4.Gymnasium%20Customization.ipynb)。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# 导入必要的包\n", + "from __future__ import annotations\n", + "\n", + "import random\n", + "import omnisafe\n", + "from typing import Any, ClassVar\n", + "\n", + "import torch\n", + "from gymnasium import spaces\n", + "\n", + "from omnisafe.envs.core import CMDP, env_register, env_unregister" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# 定义环境类\n", + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "您需要关注上面这段代码的如下细节:\n", + "\n", + "- **任务名称定义** 在 `_support_envs`中提供环境受支持的任务名称。\n", + "- **Wrapper配置** 通过设定 `need_auto_reset_wrapper`和 `need_time_limit_wrapper` 来定义自动重置和限制时间。\n", + "- **并行环境数量** 如果您的环境支持向量化并行,请通过 `_num_envs` 参数进行设定。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成 `__init__`函数定义。此处需要给出环境的动作空间与观测空间。您需要根据您当前在设计的具体任务来定义。例如:\n", + "```python\n", + "if env_id == 'Example-v0':\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "elif env_id == 'Example-v1':\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + "else:\n", + " raise NotImplementedError\n", + "```\n", + "**请注意:** 由于需要为上层模块提供标准的接口,因此在设计环境时请遵循 `self._observation_space` 以及 `self._action_space` 这两个变量名**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成环境初始化相关函数的定义。`reset` 和 `set_seed` 是OmniSafe环境初始化的标准接口。其中 `reset` 重置环境状态与计步器。 `set_seed` 通过设定随机种子确保实验的可复现性。而带有`@property`装饰器的`max_episode_steps`函数用于为`TimeLimit` Wrapper传递需要限制的每幕最大步数。实现参考如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, 
shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成功能性函数的定义。`render` 函数用于渲染环境;`close` 函数用于训练结束后的清理。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10\n", + "\n", + " def render(self) -> Any:\n", + " pass\n", + "\n", + " def close(self) -> None:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "完成 `step` 函数定义。此处是您定制化环境的核心交互逻辑。您只需按照本例中的数据输入与输出格式进行调整即可。您也可以直接将本例中的随机交互动态更改为您的环境动态。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1'] # 受支持的任务名称\n", + " metadata: ClassVar[dict[str, int]] = {}\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " self._count = 0\n", + " self._num_envs = 1\n", + " self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n", + " self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " random.seed(seed)\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict]:\n", + " if seed is not None:\n", + " self.set_seed(seed)\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " self._count = 0\n", + " return obs, {}\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> None:\n", + " \"\"\"The max steps per episode.\"\"\"\n", + " return 10\n", + "\n", + " def render(self) -> Any:\n", + " pass\n", + "\n", + " def close(self) -> None:\n", + " pass\n", + "\n", + " def 
step(\n", + " self,\n", + " action: torch.Tensor,\n", + " ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:\n", + " self._count += 1\n", + " obs = torch.as_tensor(self._observation_space.sample())\n", + " reward = 2 * torch.as_tensor(random.random())\n", + " cost = 2 * torch.as_tensor(random.random())\n", + " terminated = torch.as_tensor(random.random() > 0.9)\n", + " truncated = torch.as_tensor(self._count > 10)\n", + " return obs, reward, cost, terminated, truncated, {'final_observation': obs}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接下来,我们试着运行该环境10个时间步,观察交互信息。" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------\n", + "obs: tensor([-0.5552, 0.2905, 0.0094])\n", + "reward: 1.6888437271118164\n", + "cost: 1.5159088373184204\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.0635, -0.9966, -0.4681])\n", + "reward: 0.5178334712982178\n", + "cost: 1.0225493907928467\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.4385, 0.0678, -0.3470])\n", + "reward: 1.5675971508026123\n", + "cost: 0.6066254377365112\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.8278, -0.5252, -0.1799])\n", + "reward: 1.1667640209197998\n", + "cost: 1.8162257671356201\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([ 0.1086, -0.5711, 0.7751])\n", + "reward: 0.5636757016181946\n", + "cost: 1.511608362197876\n", + "terminated: False\n", + "truncated: False\n", + "********************\n", + "--------------------\n", + "obs: tensor([-0.3585, 0.8011, 0.2172])\n", + "reward: 0.5010126829147339\n", + "cost: 1.8194924592971802\n", + "terminated: True\n", + "truncated: False\n", + "********************\n" + ] + } + ], + "source": [ + "env = ExampleEnv(env_id='Example-v0')\n", + "env.reset(seed=0)\n", + "while True:\n", + " action = env.action_space.sample()\n", + " obs, reward, cost, terminated, truncated, info = env.step(action)\n", + " print('-' * 20)\n", + " print(f'obs: {obs}')\n", + " print(f'reward: {reward}')\n", + " print(f'cost: {cost}')\n", + " print(f'terminated: {terminated}')\n", + " print(f'truncated: {truncated}')\n", + " print('*' * 20)\n", + " if terminated or truncated:\n", + " break\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "恭喜您!已经成功完成了基础的环境定义,接下来,我们将介绍如何将该环境注册入OmniSafe中,并实现环境参数传递、交互信息记录、算法训练以及结果保存等步骤。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 定制化环境嵌入\n", + "\n", + "### 快速训练\n", + "\n", + "得益于OmniSafe精心设计的注册机制,我们只需一个装饰器即可将这个环境注册到OmniSafe的环境列表中。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "@env_register\n", + "class ExampleEnv(ExampleEnv):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "注册同名环境将会报错,这是由于**环境名称冲突**。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "class CustomExampleEnv(ExampleEnv):\n", 
+ " example_configs = 1\n", + "\n", + "\n", + "env = CustomExampleEnv('Example-v0')\n", + "env.example_configs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "这时,您需要先对环境手动取消注册。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "@env_unregister\n", + "class CustomExampleEnv(ExampleEnv):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "之后,您就可以重新注册该环境了。在本教程中,我们会同时嵌套 `env_register` 和 `env_unregister` 装饰器,这是为了避免环境重复注册造成报错,即确保该环境只被注册一次,以便用户在阅读本教程时多次修改与运行代码。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CustomExampleEnv has not been registered yet\n" + ] + }, + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class CustomExampleEnv(ExampleEnv):\n", + " example_configs = 2\n", + "\n", + "\n", + "env = CustomExampleEnv('Example-v0')\n", + "env.example_configs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "随后,您可以使用OmniSafe中的算法来训练这个自定义环境。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Example-v0}/seed-000-2024-04-09-15-04-56/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-04-56/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", + "\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00020748490351252258 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.00019999999494757503 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -1.4901161193847656e-08 │\n", + "│ Loss/Loss_pi/Delta │ -1.4901161193847656e-08 │\n", + "│ Value/Adv │ 1.4901161193847656e-08 │\n", + "│ Loss/Loss_reward_critic │ 10.458966255187988 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.458966255187988 │\n", + "│ Value/reward │ -0.015489530749619007 │\n", + "│ Loss/Loss_cost_critic │ 19.141571044921875 │\n", + "│ Loss/Loss_cost_critic/Delta │ 19.141571044921875 │\n", + "│ Value/cost │ 0.05426764488220215 │\n", + "│ Time/Total │ 0.034796953201293945 │\n", + "│ Time/Rollout │ 0.01762533187866211 │\n", + "│ Time/Update │ 0.01616811752319336 │\n", + "│ Time/Epoch │ 0.03383183479309082 │\n", + "│ Time/FPS │ 295.5858459472656 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00020748490351252258 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.00019999999494757503 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -1.4901161193847656e-08 │\n", + "│ Loss/Loss_pi/Delta │ -1.4901161193847656e-08 │\n", + "│ Value/Adv │ 1.4901161193847656e-08 │\n", + "│ Loss/Loss_reward_critic │ 10.458966255187988 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.458966255187988 │\n", + "│ Value/reward │ -0.015489530749619007 │\n", + "│ Loss/Loss_cost_critic │ 19.141571044921875 │\n", + "│ Loss/Loss_cost_critic/Delta │ 19.141571044921875 │\n", + "│ Value/cost │ 0.05426764488220215 │\n", + "│ Time/Total │ 0.034796953201293945 │\n", + "│ Time/Rollout │ 0.01762533187866211 │\n", + "│ Time/Update │ 0.01616811752319336 │\n", + "│ Time/Epoch │ 0.03383183479309082 │\n", + "│ Time/FPS │ 295.5858459472656 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + 
"
Warning: trajectory cut off when rollout by epoch at 10.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m10.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 7.8531365394592285 │\n", + "│ Metrics/EpCost │ 7.931504726409912 │\n", + "│ Metrics/EpLen │ 6.666666507720947 │\n", + "│ Train/Epoch │ 1.0 │\n", + "│ Train/Entropy │ 1.4192386865615845 │\n", + "│ Train/KL │ 6.416345422621816e-05 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 9.999999747378752e-05 │\n", + "│ Train/PolicyStd │ 1.0003000497817993 │\n", + "│ TotalEnvSteps │ 20.0 │\n", + "│ Loss/Loss_pi │ -6.258487417198921e-08 │\n", + "│ Loss/Loss_pi/Delta │ -4.768371297814156e-08 │\n", + "│ Value/Adv │ 1.341104507446289e-07 │\n", + "│ Loss/Loss_reward_critic │ 38.05686950683594 │\n", + "│ Loss/Loss_reward_critic/Delta │ 27.59790325164795 │\n", + "│ Value/reward │ -0.008213319815695286 │\n", + "│ Loss/Loss_cost_critic │ 23.737285614013672 │\n", + "│ Loss/Loss_cost_critic/Delta │ 4.595714569091797 │\n", + "│ Value/cost │ 0.17113244533538818 │\n", + "│ Time/Total │ 0.0776519775390625 │\n", + "│ Time/Rollout │ 0.015673398971557617 │\n", + "│ Time/Update │ 0.011301994323730469 │\n", + "│ Time/Epoch │ 0.027007579803466797 │\n", + "│ Time/FPS │ 370.27294921875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 7.8531365394592285 │\n", + "│ Metrics/EpCost │ 7.931504726409912 │\n", + "│ Metrics/EpLen │ 6.666666507720947 │\n", + "│ Train/Epoch │ 1.0 │\n", + "│ Train/Entropy │ 1.4192386865615845 │\n", + "│ Train/KL │ 6.416345422621816e-05 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 9.999999747378752e-05 │\n", + "│ Train/PolicyStd │ 1.0003000497817993 │\n", + "│ TotalEnvSteps │ 20.0 │\n", + "│ Loss/Loss_pi │ -6.258487417198921e-08 │\n", + "│ Loss/Loss_pi/Delta │ -4.768371297814156e-08 │\n", + "│ Value/Adv │ 1.341104507446289e-07 │\n", + "│ Loss/Loss_reward_critic │ 38.05686950683594 │\n", + "│ Loss/Loss_reward_critic/Delta │ 27.59790325164795 │\n", + "│ Value/reward │ -0.008213319815695286 │\n", + "│ Loss/Loss_cost_critic │ 23.737285614013672 │\n", + "│ Loss/Loss_cost_critic/Delta │ 4.595714569091797 │\n", + "│ Value/cost │ 0.17113244533538818 │\n", + "│ Time/Total │ 0.0776519775390625 │\n", + "│ Time/Rollout │ 0.015673398971557617 │\n", + "│ Time/Update │ 0.011301994323730469 │\n", + "│ Time/Epoch │ 0.027007579803466797 │\n", + "│ Time/FPS │ 370.27294921875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + 
}, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 9.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m9.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 6.297085762023926 │\n", + "│ Metrics/EpCost │ 6.2187700271606445 │\n", + "│ Metrics/EpLen │ 5.25 │\n", + "│ Train/Epoch │ 2.0 │\n", + "│ Train/Entropy │ 1.419387698173523 │\n", + "│ Train/KL │ 5.490810053743189e-06 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0004494190216064 │\n", + "│ TotalEnvSteps │ 30.0 │\n", + "│ Loss/Loss_pi │ -1.9073486612342094e-07 │\n", + "│ Loss/Loss_pi/Delta │ -1.2814999195143173e-07 │\n", + "│ Value/Adv │ 1.0728835775353218e-07 │\n", + "│ Loss/Loss_reward_critic │ 34.77037811279297 │\n", + "│ Loss/Loss_reward_critic/Delta │ -3.2864913940429688 │\n", + "│ Value/reward │ 0.014150517992675304 │\n", + "│ Loss/Loss_cost_critic │ 27.43436050415039 │\n", + "│ Loss/Loss_cost_critic/Delta │ 3.6970748901367188 │\n", + "│ Value/cost │ 0.24021005630493164 │\n", + "│ Time/Total │ 0.12173724174499512 │\n", + "│ Time/Rollout │ 0.01879405975341797 │\n", + "│ Time/Update │ 0.011112689971923828 │\n", + "│ Time/Epoch │ 0.0299375057220459 │\n", + "│ Time/FPS │ 334.039794921875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 6.297085762023926 │\n", + "│ Metrics/EpCost │ 6.2187700271606445 │\n", + "│ Metrics/EpLen │ 5.25 │\n", + "│ Train/Epoch │ 2.0 │\n", + "│ Train/Entropy │ 1.419387698173523 │\n", + "│ Train/KL │ 5.490810053743189e-06 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0004494190216064 │\n", + "│ TotalEnvSteps │ 30.0 │\n", + "│ Loss/Loss_pi │ -1.9073486612342094e-07 │\n", + "│ Loss/Loss_pi/Delta │ -1.2814999195143173e-07 │\n", + "│ Value/Adv │ 1.0728835775353218e-07 │\n", + "│ Loss/Loss_reward_critic │ 34.77037811279297 │\n", + "│ Loss/Loss_reward_critic/Delta │ -3.2864913940429688 │\n", + "│ Value/reward │ 0.014150517992675304 │\n", + "│ Loss/Loss_cost_critic │ 27.43436050415039 │\n", + "│ Loss/Loss_cost_critic/Delta │ 3.6970748901367188 │\n", + "│ Value/cost │ 0.24021005630493164 │\n", + "│ Time/Total │ 0.12173724174499512 │\n", + "│ Time/Rollout │ 0.01879405975341797 │\n", + "│ Time/Update │ 0.011112689971923828 │\n", + "│ Time/Epoch │ 0.0299375057220459 │\n", + "│ Time/FPS │ 334.039794921875 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴─────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + 
"(6.297085762023926, 6.2187700271606445, 5.25)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 30,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 10,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Example-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "干得不错!我们已经完成了这个定制化环境的嵌入和训练。接下来,我们将进一步研究如何为环境指定超参数。\n", + "\n", + "### 参数设定\n", + "\n", + "我们从一个新的示例环境出发,假设这个环境需要传入一个名为 `num_agents` 的参数。我们将展示如何不修改OmniSafe的代码来完成参数设定。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NewExampleEnv has not been registered yet\n" + ] + } + ], + "source": [ + "@env_register\n", + "@env_unregister\n", + "class NewExampleEnv(ExampleEnv): # 创造一个新环境\n", + " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n", + " num_agents: ClassVar[int] = 1\n", + "\n", + " def __init__(self, env_id: str, **kwargs) -> None:\n", + " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n", + " self.num_agents = kwargs.get('num_agents', 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "此时,`num_agents` 参数为预设值:`1`。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_env = NewExampleEnv('NewExample-v0')\n", + "new_env.num_agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "下面我们将展示如何通过 OmniSafe 的接口对该参数进行修改并训练:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-05-09/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-09/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "2"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "custom_cfgs.update({'env_cfgs': {'num_agents': 2}})\n",
+ "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n",
+ "agent.agent._env._env.num_agents"
+ ]
+ },
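下面补充一个极简示意(并非 OmniSafe 官方示例):从上文 `agent.agent._env._env.num_agents` 的输出可以推断,`env_cfgs` 中的键值对会以关键字参数的形式传入环境的 `__init__`。因此,对环境本身而言,通过 `custom_cfgs` 传参与直接用关键字参数构造是等价的。以下代码假设沿用上文定义的 `NewExampleEnv`:

```python
# 仅作示意:假设沿用上文定义的 NewExampleEnv。
# env_cfgs 中的键值对会以关键字参数的形式传入环境的 __init__,
# 因此下面的直接构造与通过 custom_cfgs 设定 env_cfgs 在环境侧效果一致。
env = NewExampleEnv('NewExample-v0', num_agents=2)
print(env.num_agents)  # 预期输出: 2
```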
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "非常好!我们将 `num_agents` 设置为了2。这表示我们在未修改代码的情形下成功实现了超参数设定。\n",
+ "\n",
+ "### 训练信息记录\n",
+ "\n",
+ "在运行训练代码时,您可能已经发现 OmniSafe 通过 `Logger` 记录了训练信息,例如:\n",
+ "\n",
+ "```bash\n",
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Metrics ┃ Value ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ Metrics/EpRet │ 2.046875 │\n",
+ "│ Metrics/EpCost │ 2.89453125 │\n",
+ "│ Metrics/EpLen │ 3.25 │\n",
+ "│ Train/Epoch │ 3.0 │\n",
+ "...\n",
+ "```\n",
+ "那么我们可否将环境之中的信息输出到日志中呢?答案是肯定的,而且这个过程同样不需要修改OmniSafe的代码。只需要实现两个标准接口:\n",
+ "1. 在 `__init__` 函数中,将需要输出的信息添加到`self.env_spec_log`中。\n",
+ "2. 实例化 `spec_log` 函数,记录所需的信息。\n",
+ "\n",
+ "**请注意:** 目前OmniSafe仅支持在每一个epoch结束时记录这些信息,而不支持在每一个step结束时记录。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@env_register\n",
+ "@env_unregister\n",
+ "class NewExampleEnv(ExampleEnv):\n",
+ " _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n",
+ "\n",
+ " # 定义需要记录的信息\n",
+ " def __init__(self, env_id: str, **kwargs) -> None:\n",
+ " super(NewExampleEnv, self).__init__(env_id, **kwargs)\n",
+ " self.env_spec_log = {'Env/Success_counts': 0}\n",
+ "\n",
+ " # 通过step函数,与环境进行交互\n",
+ " def step(self, action):\n",
+ " obs, reward, cost, terminated, truncated, info = super().step(action)\n",
+ " success = int(reward > cost)\n",
+ " self.env_spec_log['Env/Success_counts'] += success\n",
+ " return obs, reward, cost, terminated, truncated, info\n",
+ "\n",
+ " # 在logger中记录环境信息\n",
+ " def spec_log(self, logger) -> dict[str, Any]:\n",
+ " logger.store({'Env/Success_counts': self.env_spec_log['Env/Success_counts']})\n",
+ " self.env_spec_log['Env/Success_counts'] = 0"
+ ]
+ },
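在继续训练之前,这里给出一个可选的极简示意,帮助理解 `spec_log` 的"先累积、后写入并清零"约定:训练框架会在每个 epoch 结束时调用一次 `spec_log`。下面的 `StubLogger` 只是一个假设的替身对象(并非 OmniSafe 的 `Logger`),仅实现了最小的 `store` 接口;环境沿用上文定义的 `NewExampleEnv`:

```python
# 仅作示意:StubLogger 是假设的替身对象,并非 OmniSafe 的 Logger。
import torch


class StubLogger:
    def store(self, data: dict) -> None:
        # 模仿 Logger.store 的最小接口,仅打印收到的统计量
        print('stored:', data)


env = NewExampleEnv('NewExample-v0')
env.reset(seed=0)
for _ in range(5):
    # 手动交互若干步,期间 Env/Success_counts 会在环境内部累积
    action = torch.as_tensor(env.action_space.sample())
    env.step(action)

env.spec_log(StubLogger())  # 正式训练中由框架在每个 epoch 结束时调用
print(env.env_spec_log)     # 预期: {'Env/Success_counts': 0},统计量已被清零
```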
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "接下来,我们简单训练观察该信息是否被成功记录。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Logging data to ./runs/PPOLag-{NewExample-v0}/seed-000-2024-04-09-15-05-14/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-14/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00024281258811242878 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -5.662441182607836e-08 │\n", + "│ Loss/Loss_pi/Delta │ -5.662441182607836e-08 │\n", + "│ Value/Adv │ 1.2814999195143173e-07 │\n", + "│ Loss/Loss_reward_critic │ 10.477845191955566 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.477845191955566 │\n", + "│ Value/reward │ -0.0091781010851264 │\n", + "│ Loss/Loss_cost_critic │ 18.525999069213867 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.525999069213867 │\n", + "│ Value/cost │ 0.14141643047332764 │\n", + "│ Time/Total │ 0.030597209930419922 │\n", + "│ Time/Rollout │ 0.017596960067749023 │\n", + "│ Time/Update │ 0.012219905853271484 │\n", + "│ Time/Epoch │ 0.02985072135925293 │\n", + "│ Time/FPS │ 335.00830078125 │\n", + "│ Env/Success_counts │ 1.5 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ 5.625942230224609 │\n", + "│ Metrics/EpCost │ 6.960921287536621 │\n", + "│ Metrics/EpLen │ 5.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4189385175704956 │\n", + "│ Train/KL │ 0.00024281258811242878 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 1.0 │\n", + "│ Train/PolicyRatio/Min │ 1.0 │\n", + "│ Train/PolicyRatio/Max │ 1.0 │\n", + "│ Train/PolicyRatio/Std │ 0.0 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 1.0 │\n", + "│ TotalEnvSteps │ 10.0 │\n", + "│ Loss/Loss_pi │ -5.662441182607836e-08 │\n", + "│ Loss/Loss_pi/Delta │ -5.662441182607836e-08 │\n", + "│ Value/Adv │ 1.2814999195143173e-07 │\n", + "│ Loss/Loss_reward_critic │ 10.477845191955566 │\n", + "│ Loss/Loss_reward_critic/Delta │ 10.477845191955566 │\n", + "│ Value/reward │ -0.0091781010851264 │\n", + "│ Loss/Loss_cost_critic │ 18.525999069213867 │\n", + "│ Loss/Loss_cost_critic/Delta │ 18.525999069213867 │\n", + "│ Value/cost │ 0.14141643047332764 │\n", + "│ Time/Total │ 0.030597209930419922 │\n", + "│ Time/Rollout │ 0.017596960067749023 │\n", + "│ Time/Update │ 0.012219905853271484 │\n", + "│ Time/Epoch │ 0.02985072135925293 │\n", + "│ Time/FPS │ 335.00830078125 │\n", + "│ Env/Success_counts │ 1.5 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": 
{ + "text/plain": [ + "(5.625942230224609, 6.960921287536621, 5.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_cfgs.update({'train_cfgs': {'total_steps': 10}})\n", + "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "漂亮!上述代码将在终端输出了环境特化的信息 `Env/Success_counts`。这一过程并不需要对原代码作出改动。\n", + "\n", + "## 总结\n", + "OmniSafe旨在成为安全强化学习的基础软件。我们将持续完善OmniSafe的环境接口标准,使OmniSafe能够适应各种安全强化学习任务,赋能多元安全场景。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/zh-cn/4.Environment Customization from Community.ipynb b/tutorials/zh-cn/4.Environment Customization from Community.ipynb new file mode 100644 index 000000000..6e05bd036 --- /dev/null +++ b/tutorials/zh-cn/4.Environment Customization from Community.ipynb @@ -0,0 +1,903 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OmniSafe Tutorial - Environment Customization from Community\n", + "\n", + "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n", + "\n", + "Documentation: https://omnisafe.readthedocs.io/en/latest/\n", + "\n", + "Gymnasium: https://github.com/Farama-Foundation/Gymnasium\n", + "\n", + "[Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API.\n", + "\n", + "## 引言\n", + "\n", + "在本节当中,我们将为您介绍如何将一个来自社区的已有环境嵌入OmniSafe中。[Gymnasium](https://github.com/Farama-Foundation/Gymnasium)提供的系列任务已被广泛应用至强化学习中。具体而言,本节将以[Pendulum-v1](https://gymnasium.farama.org/environments/classic_control/pendulum/)为例,展示如何将Gymnasium的任务嵌入OmniSafe。\n", + "\n", + "## 快速安装" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 通过pip安装(如果您已经安装,请忽略此段代码)\n", + "%pip install omnisafe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 通过源代码安装(如果您已经安装,请忽略此段代码)\n", + "## 克隆仓库\n", + "%git clone https://github.com/PKU-Alignment/omnisafe\n", + "%cd omnisafe\n", + "\n", + "## 完成安装\n", + "%pip install -e ." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gymnasium任务嵌入\n", + "环境嵌入需要的核心是为SafeRL智能体交互与训练提供足够的静态或动态信息,本节将详细介绍嵌入环境所必须定义的变量以及相应规范。我们将首先按照编写代码的逻辑顺序地展示整个嵌入过程,让您有一个初步的了解。然后我们将回顾所有代码,总结并整理您在自定义环境时需要进行的适配。\n", + "\n", + "\n", + "### 快速开始\n", + "首先,导入本教程所需要的所有外部变量。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# 导入必要的包\n", + "from __future__ import annotations\n", + "\n", + "from typing import Any, ClassVar\n", + "import gymnasium\n", + "import torch\n", + "import numpy as np\n", + "import omnisafe\n", + "\n", + "from omnisafe.envs.core import CMDP, env_register, env_unregister\n", + "from omnisafe.typing import DEVICE_CPU" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "随后,创建一个名为`ExampleMuJoCoEnv`的类,它需要继承的父类是`CMDP`。(这是因为我们想把环境的交互形式转换为CMDP的范式,您可以根据需要定义新的抽象类以实现新的范式)。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class ExampleMuJoCoEnv(CMDP):\n", + " _support_envs: ClassVar[list[str]] = ['Pendulum-v1'] # 支持的任务名称\n", + "\n", + " need_auto_reset_wrapper = True # 是否需要 `AutoReset` Wrapper\n", + " need_time_limit_wrapper = True # 是否需要 `TimeLimit` Wrapper\n", + "\n", + " def __init__(\n", + " self,\n", + " env_id: str,\n", + " num_envs: int = 1,\n", + " device: torch.device = DEVICE_CPU,\n", + " **kwargs: Any,\n", + " ) -> None:\n", + " super().__init__(env_id)\n", + " self._num_envs = num_envs\n", + " self._env = gymnasium.make(id=env_id, autoreset=True, **kwargs) # 实例化环境对象\n", + " self._action_space = self._env.action_space # 指定动作空间,以供算法层初始化读取\n", + " self._observation_space = self._env.observation_space # 指定观测空间,以供算法层初始化读取\n", + " self._device = device # 可选项,使用GPU加速。默认为CPU\n", + "\n", + " def reset(\n", + " self,\n", + " seed: int | None = None,\n", + " options: dict[str, Any] | None = None,\n", + " ) -> tuple[torch.Tensor, dict[str, Any]]:\n", + " obs, info = self._env.reset(seed=seed, options=options) # 重置环境\n", + " return (\n", + " torch.as_tensor(obs, dtype=torch.float32, device=self._device),\n", + " info,\n", + " ) # 将重置后的观测转换为torch tensor。\n", + "\n", + " @property\n", + " def max_episode_steps(self) -> int | None:\n", + " return self._env.env.spec.max_episode_steps # 返回环境每一幕的最大交互步数\n", + "\n", + " def set_seed(self, seed: int) -> None:\n", + " self.reset(seed=seed) # 设定环境的随机种子以实现可复现性\n", + "\n", + " def render(self) -> Any:\n", + " return self._env.render() # 返回环境渲染的图像\n", + "\n", + " def close(self) -> None:\n", + " self._env.close() # 训练结束后,释放环境实例\n", + "\n", + " def step(\n", + " self,\n", + " action: torch.Tensor,\n", + " ) -> tuple[\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " torch.Tensor,\n", + " dict[str, Any],\n", + " ]:\n", + " obs, reward, terminated, truncated, info = self._env.step(\n", + " action.detach().cpu().numpy(),\n", + " ) # 读取与环境交互后的动态信息\n", + " cost = np.zeros_like(reward) # Gymnasium并显式包含安全约束,此处仅为占位。\n", + " obs, reward, cost, terminated, truncated = (\n", + " torch.as_tensor(x, dtype=torch.float32, device=self._device)\n", + " for x in (obs, reward, cost, terminated, truncated)\n", + " ) # 将动态信息转换为torch tensor。\n", + " if 'final_observation' in info:\n", + " info['final_observation'] = np.array(\n", + " [\n", + " array if array is not None else np.zeros(obs.shape[-1])\n", + " for array in info['final_observation']\n", + " ],\n", + " )\n", + " info['final_observation'] = torch.as_tensor(\n", + " info['final_observation'],\n", + 
" dtype=torch.float32,\n", + " device=self._device,\n", + " ) # 将info中记录的上一幕final observation转换为torch tensor。\n", + "\n", + " return obs, reward, cost, terminated, truncated, info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "有关上述代码的具体含义,我们已提供了详细的注释说明。更详细的解释可参考[Tutorial 3: Environment Customization from Zero](./3.Environment%20Customization.ipynb)。我们将要点总结如下:\n", + "\n", + "- **OmniSafe初始化需要的静态变量**\n", + "\n", + "| 静态信息 | 必须 | 定义 | 类型 | 例子 |\n", + "|:---:|:---:|:---:|:---:|:---:|\n", + "| `need_auto_reset_wrapper` | 是 | 是否需要 `AutoReset` Wrapper | `bool`变量 | `True` |\n", + "| `need_time_limit_wrapper` | 是 | 是否需要 `TimeLimit` Wrapper | `bool`变量 | `True` |\n", + "| `_action_space` | 是 | 动作空间 | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(2,)` |\n", + "| `_observation_space` | 是 | 观测空间 | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(3,)` |\n", + "| `max_episode_steps` | 是 | 环境每一幕的最大交互步数 | 带有`@property`装饰器的,返回值为`int`或`None`类型变量的函数 | 参考上方代码块 |\n", + "| `_num_envs` | 否 | 并行环境数 | `int`变量 | 5 |\n", + "| `_device` | 否 | torch计算设备 | `torch.device`变量 | `DEVICE_CPU` |\n", + "\n", + "- **OmniSafe需要环境提供的动态变量**\n", + "\n", + "OmniSafe的智能体主要通过`reset`和`step`函数与环境进行动态交互。您需要确保定制化环境的返回值类型、个数与顺序与上述例子一致,更具体地:\n", + "\n", + "| 动态信息 | 类型 | 个数 | 顺序 |\n", + "|:---:|:---:|:---:|:---:|\n", + "| `step` | `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]` | 6 | `obs`, `reward`, `cost`, `terminated`, `truncated`, `info` |\n", + "| `reset` | `tuple[torch.Tensor, dict[str, Any]]` | 2 | `obs`, `info` |\n", + "\n", + "- **注意事项**\n", + "\n", + "1. 尽管`_num_envs`与`_device`并不是必须指定的,但也请您在`__init__`函数中保留这两个参数的输入接口。\n", + "2. `_num_envs`是实例化多个环境并行采样的高级参数,它表示实例化环境的数目。如果您的定制化环境同样支持并行数指定,请通过`_num_envs`指定,而不用再定义一个新的接口。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "随后,将上述环境通过注册装饰器`@env_register`注册入OmniSafe中,即可完成训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ExampleMuJoCoEnv has not been registered yet\n", + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-05-55/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-55/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", + "\n" + ], + "text/plain": [ + "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1616.242431640625 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4185898303985596 │\n", + "│ Train/KL │ 0.0007516025798395276 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Min │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Max │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Std │ 0.0075334208086133 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996514320373535 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.08751548826694489 │\n", + "│ Loss/Loss_pi/Delta │ 0.08751548826694489 │\n", + "│ Value/Adv │ -0.398242324590683 │\n", + "│ Loss/Loss_reward_critic │ 16605.1796875 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16605.1796875 │\n", + "│ Value/reward │ 0.0049050007946789265 │\n", + "│ Loss/Loss_cost_critic │ 0.052194785326719284 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.052194785326719284 │\n", + "│ Value/cost │ 0.07966174930334091 │\n", + "│ Time/Total │ 0.2075355052947998 │\n", + "│ Time/Rollout │ 0.1734788417816162 │\n", + "│ Time/Update │ 0.033020973205566406 │\n", + "│ Time/Epoch │ 0.20653653144836426 │\n", + "│ Time/FPS │ 968.3539428710938 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴───────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1616.242431640625 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.4185898303985596 │\n", + "│ Train/KL │ 0.0007516025798395276 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Min │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Max │ 0.9966228604316711 │\n", + "│ Train/PolicyRatio/Std │ 0.0075334208086133 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996514320373535 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.08751548826694489 │\n", + "│ Loss/Loss_pi/Delta │ 0.08751548826694489 │\n", + "│ Value/Adv │ -0.398242324590683 │\n", + "│ Loss/Loss_reward_critic │ 16605.1796875 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16605.1796875 │\n", + "│ Value/reward │ 0.0049050007946789265 │\n", + "│ Loss/Loss_cost_critic │ 0.052194785326719284 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.052194785326719284 │\n", + "│ Value/cost │ 0.07966174930334091 │\n", + "│ Time/Total │ 0.2075355052947998 │\n", + "│ Time/Rollout │ 0.1734788417816162 │\n", + "│ Time/Update │ 0.033020973205566406 │\n", + "│ Time/Epoch │ 0.20653653144836426 │\n", + "│ Time/FPS │ 968.3539428710938 │\n", + "│ Metrics/LagrangeMultiplier/Mea │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + 
}, + { + "data": { + "text/plain": [ + "(-1616.242431640625, 0.0, 200.0)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@env_register\n", + "@env_unregister # 避免重复运行单元格时产生\"环境已注册\"报错\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + " pass\n", + "\n", + "\n", + "custom_cfgs = {\n", + " 'train_cfgs': {\n", + " 'total_steps': 200,\n", + " },\n", + " 'algo_cfgs': {\n", + " 'steps_per_epoch': 200,\n", + " 'update_iters': 1,\n", + " },\n", + "}\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 高级使用\n", + "除了上述使用方式外,来自社区的环境还可以享受OmniSafe的环境特定参数指定以及信息记录的特性。我们将详细展示具体操作方式。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 特定参数指定\n", + "\n", + "以`Pendulum-v1`为例,根据Gymnasium的官方文档,创建该任务时可指定一个特定参数为`g`,即重力加速度。我们首先来看看它的默认取值:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n" + ] + }, + { + "data": { + "text/html": [ + "
Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-05-58/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-58/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "10.0"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "@env_register\n",
+ "@env_unregister # 避免重复运行单元格时产生\"环境已注册\"报错\n",
+ "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
+ " def __getattr__(self, name: str) -> Any:\n",
+ " \"\"\"Get the attribute of the environment.\"\"\"\n",
+ " if name.startswith('_'):\n",
+ " raise AttributeError(f'attempted to get missing private attribute {name}')\n",
+ " return getattr(self._env, name)\n",
+ "\n",
+ "\n",
+ "custom_cfgs = {\n",
+ " 'train_cfgs': {\n",
+ " 'total_steps': 200,\n",
+ " },\n",
+ " 'algo_cfgs': {\n",
+ " 'steps_per_epoch': 200,\n",
+ " 'update_iters': 1,\n",
+ " },\n",
+ "}\n",
+ "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
+ "agent.agent._env._env.g"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "我们实现了一个名为`__get_attr__`的魔法函数,用于调用并查看当前实例化的环境中的特定参数。在本例中,我们发现重力加速度`g`的默认值是10.0\n",
+ "\n",
+ "通过查阅Gymnasium的文档,该参数可以在调用`gymnasium.make`函数创建环境的过程中指定。OmniSafe是否支持定制化环境的特定参数传递呢?答案是肯定的,具体操作也非常简单:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-06-01/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-06-01/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "9.8"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "custom_cfgs.update({'env_cfgs': {'g': 9.8}})\n",
+ "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
+ "agent.agent._env._env.g"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "非常好!重力加速度取值被我们更改为了9.8。我们只需要对`env_cfgs`进行操作,将需要定制参数的键与值指定,即可实现环境的特定参数传递。"
+ ]
+ },
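+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next cell is a minimal, hypothetical sketch (it is not OmniSafe source code) of how an entry such as `g` in `env_cfgs` can reach the wrapped task: the configured keys are handed to the environment constructor as keyword arguments, which the customized environment is assumed to forward to `gymnasium.make`. The class name `ForwardingEnvSketch` is purely illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gymnasium\n",
+ "\n",
+ "\n",
+ "class ForwardingEnvSketch:  # hypothetical class, for illustration only\n",
+ "    def __init__(self, env_id: str, **kwargs) -> None:\n",
+ "        # each key set in env_cfgs (e.g. g) is assumed to arrive here as a\n",
+ "        # keyword argument and to be forwarded to gymnasium.make\n",
+ "        self._env = gymnasium.make(env_id, **kwargs)\n",
+ "\n",
+ "\n",
+ "sketch = ForwardingEnvSketch('Pendulum-v1', g=9.8)\n",
+ "sketch._env.unwrapped.g  # expected to be 9.8"
+ ]
+ },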
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 信息记录\n",
+ "\n",
+ "`Pendulum-v1`任务有许多特定的动态信息,我们将为您介绍如何通过OmniSafe的`Logger`记录这些信息。具体而言,我们将以每幕角速度`angular_velocity`的最大值以及累计值为例为您讲解。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Logging data to ./runs/PPOLag-{Pendulum-v1}/seed-000-2024-04-09-15-06-03/progress.csv\n", + "\n" + ], + "text/plain": [ + "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-06-03/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Save with config in config.json\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1;33mSave with config in config.json\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "INFO: Start training\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32mINFO: Start training\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Warning: trajectory cut off when rollout by epoch at 200.0 steps.\n", + "\n" + ], + "text/plain": [ + "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Metrics ┃ Value ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1607.6717529296875 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.418560266494751 │\n", + "│ Train/KL │ 0.0005777678452432156 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Min │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Max │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Std │ 0.005412393249571323 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996219277381897 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.09192709624767303 │\n", + "│ Loss/Loss_pi/Delta │ 0.09192709624767303 │\n", + "│ Value/Adv │ -0.4177907109260559 │\n", + "│ Loss/Loss_reward_critic │ 16393.2265625 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16393.2265625 │\n", + "│ Value/reward │ 0.00719139538705349 │\n", + "│ Loss/Loss_cost_critic │ 0.05219484493136406 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.05219484493136406 │\n", + "│ Value/cost │ 0.07949987053871155 │\n", + "│ Time/Total │ 0.2163846492767334 │\n", + "│ Time/Rollout │ 0.18010711669921875 │\n", + "│ Time/Update │ 0.03433847427368164 │\n", + "│ Time/Epoch │ 0.21448636054992676 │\n", + "│ Time/FPS │ 932.4664306640625 │\n", + "│ Env/Max_angular_velocity │ 2.9994523525238037 │\n", + "│ Env/Cumulative_angular_velocity │ 1.0643725395202637 │\n", + "│ Metrics/LagrangeMultiplier/Mean │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└─────────────────────────────────┴───────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mMetrics \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Metrics/EpRet │ -1607.6717529296875 │\n", + "│ Metrics/EpCost │ 0.0 │\n", + "│ Metrics/EpLen │ 200.0 │\n", + "│ Train/Epoch │ 0.0 │\n", + "│ Train/Entropy │ 1.418560266494751 │\n", + "│ Train/KL │ 0.0005777678452432156 │\n", + "│ Train/StopIter │ 1.0 │\n", + "│ Train/PolicyRatio/Mean │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Min │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Max │ 0.9981198310852051 │\n", + "│ Train/PolicyRatio/Std │ 0.005412393249571323 │\n", + "│ Train/LR │ 0.0 │\n", + "│ Train/PolicyStd │ 0.9996219277381897 │\n", + "│ TotalEnvSteps │ 200.0 │\n", + "│ Loss/Loss_pi │ 0.09192709624767303 │\n", + "│ Loss/Loss_pi/Delta │ 0.09192709624767303 │\n", + "│ Value/Adv │ -0.4177907109260559 │\n", + "│ Loss/Loss_reward_critic │ 16393.2265625 │\n", + "│ Loss/Loss_reward_critic/Delta │ 16393.2265625 │\n", + "│ Value/reward │ 0.00719139538705349 │\n", + "│ Loss/Loss_cost_critic │ 0.05219484493136406 │\n", + "│ Loss/Loss_cost_critic/Delta │ 0.05219484493136406 │\n", + "│ Value/cost │ 0.07949987053871155 │\n", + "│ Time/Total │ 0.2163846492767334 │\n", + "│ Time/Rollout │ 0.18010711669921875 │\n", + "│ Time/Update │ 0.03433847427368164 │\n", + "│ Time/Epoch │ 0.21448636054992676 │\n", + "│ Time/FPS │ 932.4664306640625 │\n", + "│ Env/Max_angular_velocity │ 2.9994523525238037 │\n", + "│ Env/Cumulative_angular_velocity │ 1.0643725395202637 │\n", + "│ Metrics/LagrangeMultiplier/Mean │ 0.0 │\n", + "│ 
Metrics/LagrangeMultiplier/Min │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Max │ 0.0 │\n", + "│ Metrics/LagrangeMultiplier/Std │ 0.0 │\n", + "└─────────────────────────────────┴───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(-1607.6717529296875, 0.0, 200.0)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from omnisafe.common.logger import Logger\n", + "\n", + "\n", + "@env_register\n", + "@env_unregister # avoid the 'environment already registered' error when re-running this cell\n", + "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n", + "\n", + " def __init__(self, env_id, num_envs, device, **kwargs):\n", + " super().__init__(env_id, num_envs, device, **kwargs)\n", + " self.env_spec_log = {\n", + " 'Env/Max_angular_velocity': 0.0,\n", + " 'Env/Cumulative_angular_velocity': 0.0,\n", + " } # re-declare and initialize the logged keys in the constructor\n", + "\n", + " def spec_log(self, logger: Logger) -> None:\n", + " for key, value in self.env_spec_log.items():\n", + " logger.store({key: value})\n", + " self.env_spec_log[key] = 0.0\n", + "\n", + " def step(self, action):\n", + " obs, reward, cost, terminated, truncated, info = super().step(action=action)\n", + " angular_velocity = obs[-1].item()\n", + " self.env_spec_log['Env/Max_angular_velocity'] = max(\n", + " self.env_spec_log['Env/Max_angular_velocity'], angular_velocity\n", + " )\n", + " self.env_spec_log['Env/Cumulative_angular_velocity'] += angular_velocity\n", + " return obs, reward, cost, terminated, truncated, info\n", + "\n", + "\n", + "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n", + "agent.learn()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! We successfully recorded the desired environment-specific information in the `Logger`. Notably, we did not modify any of OmniSafe's source code in the process." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "In this section we used Gymnasium's classic `Pendulum-v1` environment to walk through the interface adaptation and information needed to embed an existing community environment into OmniSafe. We hope this tutorial helps with your own environment customization. If you would like your environment to become one of OmniSafe's officially supported environments, or if you run into difficulties while customizing an environment, feel free to reach out to us via the [Issues](https://github.com/PKU-Alignment/omnisafe/issues), [Pull Requests](https://github.com/PKU-Alignment/omnisafe/pulls), and [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions) modules." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "omnisafe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}