-
Notifications
You must be signed in to change notification settings - Fork 2
/
experience_replay.py
47 lines (38 loc) · 1.48 KB
/
experience_replay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from utils import random_index
from collections import deque
class ExperienceReplay:
    """Column-oriented replay buffer backed by one ``deque`` per column.

    Data is stored as a dict mapping a column name to a
    ``deque(maxlen=capacity)``; row ``i`` of the buffer is made of the
    ``i``-th element of every column.  When ``capacity`` is not ``None`` each
    deque discards its oldest entries once full.
    """

    def __init__(self, capacity=None):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of entries kept per column, passed as
                ``maxlen`` to every column deque.  ``None`` means unbounded.
        """
        self.capacity = capacity
        # Created eagerly so that no method has to rely on catching
        # AttributeError to detect "no data yet".
        self._data = {}

    def __len__(self):
        """Return the length of the first stored column, or 0 if there is none."""
        # The default covers a buffer with no columns at all.
        return next((len(col) for col in self._data.values()), 0)

    def items(self):
        """Return the ``(name, deque)`` view over the stored columns.

        Raises:
            ValueError: If the buffer holds no columns.
        """
        if not self._data:
            raise ValueError("Trying to iterate over an empty dataset")
        return self._data.items()

    def add_column(self, name, examples):
        """Add a new column called ``name`` filled with ``examples``.

        Validation happens before any mutation, so a rejected call leaves the
        buffer untouched.

        Args:
            name: Key of the new column; must not already exist.
            examples: Values for the column.  When the buffer already has
                rows, ``len(examples)`` must equal ``len(self)``.

        Raises:
            AssertionError: If ``name`` already exists, or if the buffer is
                non-empty and ``examples`` has a different length.  Raised
                explicitly so the check also runs under ``python -O``.
        """
        if name in self._data:
            raise AssertionError("column {!r} already exists".format(name))
        current = len(self)
        if current != 0 and current != len(examples):
            raise AssertionError(
                "column {!r} has {} entries, expected {}".format(
                    name, len(examples), current
                )
            )
        self._data[name] = deque(examples, maxlen=self.capacity)

    def append(self, example):
        """Append one row.

        Args:
            example: Mapping of column name to a single value.  If the buffer
                has no columns yet, the columns are created from its keys.
                NOTE(review): keys that are omitted are not filled in, so
                callers are expected to supply every existing column.

        Raises:
            KeyError: If the buffer already has columns and ``example``
                contains a key that is not one of them.  Nothing is appended
                in that case.
        """
        if not self._data:
            self._data = {
                key: deque([value], maxlen=self.capacity)
                for key, value in example.items()
            }
            return
        row = list(example.items())
        # Check every key first so a bad key cannot leave the columns with
        # different lengths.
        for key, _ in row:
            if key not in self._data:
                raise KeyError(key)
        for key, value in row:
            self._data[key].append(value)

    def extend(self, examples):
        """Append several rows at once.

        Args:
            examples: Mapping of column name to an iterable of values.  If the
                buffer has no columns yet, the columns are created from its
                keys.  NOTE(review): the iterables are not checked for equal
                length; callers are expected to keep columns aligned.

        Raises:
            KeyError: If the buffer already has columns and ``examples``
                contains a key that is not one of them.  Nothing is added in
                that case.
        """
        if not self._data:
            self._data = {
                key: deque(values, maxlen=self.capacity)
                for key, values in examples.items()
            }
            return
        batch = list(examples.items())
        # Check every key first so a bad key cannot leave the columns with
        # different lengths.
        for key, _ in batch:
            if key not in self._data:
                raise KeyError(key)
        for key, values in batch:
            self._data[key].extend(values)

    def sample(self, size):
        """Return a dict of lists holding the rows picked by ``random_index``.

        Args:
            size: Forwarded to ``random_index`` together with ``len(self)``
                and ``replace=False`` -- presumably the number of distinct
                row indices to draw; confirm against ``utils.random_index``.

        Returns:
            A dict with the same keys as the buffer; each value is a list with
            that column's entries at the drawn indices, in drawn order.
        """
        idx = random_index(len(self), size, replace=False)
        return {k: [col[i] for i in idx] for k, col in self._data.items()}