-
Notifications
You must be signed in to change notification settings - Fork 0
/
policy.py
30 lines (23 loc) · 893 Bytes
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import random
from typing import Dict
from action import Action
from state import State
class Policy(dict):
    """A stochastic policy stored as a mapping ``state -> {action: weight}``.

    ``Policy`` is a plain ``dict`` subclass: each key is a ``State`` and each
    value is a ``Dict[Action, float]`` holding the weight of every action
    listed for that state. Presumably the weights form a probability
    distribution (sum to 1) -- NOTE(review): nothing here enforces that.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Create a policy, optionally pre-populated exactly like ``dict(...)``.

        With no arguments this yields an empty policy (the original behavior);
        any positional/keyword arguments are forwarded to ``dict.__init__``.
        """
        super().__init__(*args, **kwargs)

    def __str__(self) -> str:
        """Render every state on its own line, followed by its action weights.

        Each action is written on a separate line as ``" {action}: {weight}"``.
        """
        lines = []
        for state, action_probs in self.items():
            lines.append(f"{state}\n")
            for action, prob in action_probs.items():
                lines.append(f" {action}: {prob}\n")
        return "".join(lines)

    def set_action_probs(self, state: State, action_probs: Dict[Action, float]) -> None:
        """Set the action weights for ``state``, replacing any previous entry.

        The mapping is stored by reference (not copied), so later mutations of
        ``action_probs`` by the caller are visible through this policy.
        """
        self[state] = action_probs

    def get_action_prob(self, state: State, action: Action) -> float:
        """Return the stored weight of ``action`` in ``state``.

        Raises:
            KeyError: if ``state`` has no entry, or ``action`` is not listed
                for that state.
        """
        # Index directly rather than via ``self.get`` so an unknown state
        # raises a KeyError naming it, not a TypeError on ``None``.
        action_probs: Dict[Action, float] = self[state]
        return action_probs[action]

    def get_action(self, state: State) -> Action:
        """Sample one action for ``state`` in proportion to its stored weights.

        Raises:
            KeyError: if ``state`` has no entry.
            Whatever ``random.choices`` raises for an empty mapping or for
            weights it rejects (e.g. all zero); the exact type depends on the
            Python version.
        """
        action_probs: Dict[Action, float] = self[state]
        # ``random.choice`` accepts no ``weights``; ``random.choices`` does and
        # returns a list of ``k`` samples, so take its single element.
        # Keys and values of one dict iterate in the same order, so the two
        # lists stay aligned.
        return random.choices(
            list(action_probs.keys()),
            weights=list(action_probs.values()),
            k=1,
        )[0]