Skip to content

Commit

Permalink
Temp push
Browse files Browse the repository at this point in the history
  • Loading branch information
HakiRose committed Feb 26, 2020
1 parent e4c6812 commit 9d3f78e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
6 changes: 5 additions & 1 deletion textworld/challenges/spaceship/agent_design_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _discount_rewards(self, last_values):

def play(agent, path, max_step=50, nb_episodes=10, verbose=True):
"""
This code uses the cooking agent design in the spaceship game.
This code uses the agent design in the spaceship game.
:param agent: the obj of NeuralAgent, a sample object for the agent
:param path: The path to the game (environment model)
Expand Down Expand Up @@ -302,9 +302,13 @@ def play(agent, path, max_step=50, nb_episodes=10, verbose=True):
nb_moves = 0
while not done:
command = agent.act(obs, score, done, infos)
print(command, "....", end="")
obs, score, done, infos = env.step(command)
nb_moves += 1
agent.act(obs, score, done, infos) # Let the agent know the game is done.
print(score)
print(obs)
print('-------------------------------------')

if verbose:
print(".", end="")
Expand Down
14 changes: 9 additions & 5 deletions textworld/challenges/spaceship/build_agent_TW_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class NeuralAgent:
GAMMA = 0.9

def __init__(self) -> None:
self._initialized = False
self._epsiode_has_started = False
# self._initialized = False
# self._epsiode_has_started = False
self.id2word = ["<PAD>", "<UNK>"]
self.word2id = {w: i for i, w in enumerate(self.id2word)}

Expand Down Expand Up @@ -138,6 +138,7 @@ def act(self, obs: str, score: int, done: bool, infos: Mapping[str, Any]) -> Opt
self.transitions[-1][0] = reward # Update reward information.

self.stats["max"]["score"].append(score)

if self.no_train_step % self.UPDATE_FREQUENCY == 0:
# Update model
returns, advantages = self._discount_rewards(values)
Expand Down Expand Up @@ -252,10 +253,12 @@ def play(agent, path, max_step=50, nb_episodes=10, verbose=True):
print(os.path.basename(path), end="")

# Collect some statistics: nb_steps, final reward.
avg_moves, avg_scores, avg_norm_scores, seed_h = [], [], [], None
avg_moves, avg_scores, avg_norm_scores, seed_h = [], [], [], 4567
for no_episode in range(nb_episodes):
obs, infos = env.reset() # Start new episode.
seed_h = env.env.textworld_env._wrapped_env.seed(init_seed=seed_h)

env.env.textworld_env._wrapped_env.seed(seed=seed_h)
seed_h += 1

score = 0
done = False
Expand Down Expand Up @@ -288,7 +291,8 @@ def play(agent, path, max_step=50, nb_episodes=10, verbose=True):
agent.train() # Tell the agent it should update its parameters.

starttime = time()
play(agent, "./games/levelMedium.ulx", nb_episodes=500, verbose=False) # Medium level game.
print(os.path.realpath("./games/levelMedium.ulx"))
play(agent, "./games/levelMedium.ulx", nb_episodes=25, verbose=False) # Medium level game.
print("Trained in {:.2f} secs".format(time() - starttime))

print('============== Time To Test ============== ')
Expand Down

0 comments on commit 9d3f78e

Please sign in to comment.