-
Notifications
You must be signed in to change notification settings - Fork 2
/
example_reinforcement_learning.py
29 lines (22 loc) · 1.17 KB
/
example_reinforcement_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright (c) Michael Mazanetz (NovaData Solutions LTD.), Silvia Amabilino (NovaData Solutions LTD.,
# University of Bristol), David Glowacki (University of Bristol). All rights reserved.
# Licensed under the GPL. See LICENSE in the project root for license information.
"""
This example shows how to use reinforcement learning to refine an already trained recurrent neural network. It requires
having run example_training.py first.
"""
from molbot import reinforcement_learning, rewards
import os
# Locate the inputs: both files are looked up in the directory that contains this script.
# Per the module docstring, they are expected to come from running example_training.py first.
current_dir = os.path.dirname(os.path.realpath(__file__))
model_file, data_handler_file = (
    os.path.join(current_dir, file_name)
    for file_name in ("example-model.h5", "example-dp.pickle")
)

# Set up the reinforcement learning object from the model file, the data handler file
# and the reward function taken from the rewards module.
reward_f = rewards.calculate_tpsa_reward
rl_settings = {
    "model_file": model_file,
    "data_handler_file": data_handler_file,
    "reward_function": reward_f,
}
rl = reinforcement_learning.Reinforcement_learning(**rl_settings)

# Run the reinforcement learning with the example's fixed settings.
training_settings = {
    "temperature": 0.75,
    "epochs": 4,
    "n_train_episodes": 15,
    "sigma": 60,
}
rl.train(**training_settings)

# Save the refined model.
# NOTE(review): this is a relative path, so the file presumably ends up in the current working
# directory rather than next to this script (unlike the inputs above) - confirm this is intended.
rl.save("rl_model.h5")