forked from turing-usp/Trainee-RL-2021
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request turing-usp#6 from turing-usp/slime-volleyball
Adiciona tarefa extra
- Loading branch information
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "SlimeVolleyball.ipynb", | ||
"provenance": [], | ||
"collapsed_sections": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "VYIynIatFL8y" | ||
}, | ||
"source": [ | ||
"# Treinando seu modelo de vôlei\n", | ||
"\n", | ||
"![](https://camo.githubusercontent.com/7eba6ff826871a5f1b9fb48d5dc7472dbf6bdbcda80bfb975bdde5a0ff71fdf5/68747470733a2f2f6f746f726f2e6e65742f696d672f736c696d6567796d2f706978656c2e676966)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "eO3Jc2pOXLgD" | ||
}, | ||
"source": [ | ||
"# Instala as bibliotecas necessárias\n", | ||
"!pip install stable_baselines3\n", | ||
"!pip install git+https://github.com/turing-usp/slimevolleygym" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "TibOzPxrFRKB" | ||
}, | ||
"source": [ | ||
"Vamos salvar os modelos treinados no seu drive: 😈😈😈" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "xf0FPkw85huS" | ||
}, | ||
"source": [ | ||
"from google.colab import drive\n", | ||
"\n", | ||
"# Permite acesso ao seu drive\n", | ||
"drive.mount('/content/drive')" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "ucaK_ds0XSh6" | ||
}, | ||
"source": [ | ||
"import os\n", | ||
"import gym\n", | ||
"import slimevolleygym\n", | ||
"from slimevolleygym import SurvivalRewardEnv\n", | ||
"from torch import nn\n", | ||
"\n", | ||
"from stable_baselines3.ppo import PPO\n", | ||
"from stable_baselines3.common.callbacks import EvalCallback\n", | ||
"\n", | ||
"# Parâmetros para o treinamento\n", | ||
"SEED = 721 # Seed\n", | ||
"NUM_TIMESTEPS = int(2e7) # Número de timesteps de treino\n", | ||
"EVAL_FREQ = 100000 # A cada quantos timesteps o modelo é avaliado (ver EvalCallback)\n", | ||
"EVAL_EPISODES = 100 # Por quantos episódios o modelo é avaliado (ver EvalCallback)\n", | ||
"LOGDIR = \"/content/drive/MyDrive/ppo\" # Local de salvar o modelo (no seu drive)\n", | ||
"\n", | ||
"# Cria o ambiente\n", | ||
"env = gym.make(\"SlimeVolley-v0\")\n", | ||
"env.seed(SEED)\n", | ||
"\n", | ||
"# Arquitetura do modelo de PPO (usar caso somente tenha interesse em alterar o padrão de 64x64 neurônios)\n", | ||
"model_arch=dict(\n", | ||
" log_std_init=-2,\n", | ||
" ortho_init=False,\n", | ||
" activation_fn=nn.ReLU,\n", | ||
" net_arch=[dict(pi=[64, 64], vf=[64, 64])] # possível trocar para [128, 128] e outros do tipo\n", | ||
" )\n", | ||
"\n", | ||
"# Cria o modelo (mudar hiperparâmetros a vontade)\n", | ||
"model = PPO('MlpPolicy', env, n_steps=4096, batch_size=32, ent_coef=0.005, n_epochs=10,\n", | ||
" learning_rate=3e-4, clip_range=0.2, gamma=0.99, gae_lambda=0.95, verbose=2\n", | ||
" # , policy_kwargs = model_arch # Descomentar caso tenha interesse em usar outras arquiteturas!!!\n", | ||
" )\n", | ||
"\n", | ||
"# Carrega modelo salvo caso já exista\n", | ||
"if os.path.exists(LOGDIR + \"/best_model\"):\n", | ||
" model = PPO.load(LOGDIR + \"/best_model\", env=env, n_steps=4096, n_epochs=10)\n", | ||
"\n", | ||
"# Salva o melhor modelo a cada avaliação\n", | ||
"eval_callback = EvalCallback(env, best_model_save_path=LOGDIR, log_path=LOGDIR, eval_freq=EVAL_FREQ, n_eval_episodes=EVAL_EPISODES)\n", | ||
"\n", | ||
"# Treinamento do Modelo\n", | ||
"model.learn(total_timesteps=NUM_TIMESTEPS, callback=eval_callback)\n", | ||
"\n", | ||
"model.save(os.path.join(LOGDIR, \"final_model\"))\n", | ||
"\n", | ||
"env.close()" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.