"""
Environment for training, validation, and testing
A state consists of a batch of partially parsed sentences
The environment takes a bactch of actions to update the state
It also pre-computes the token embeddings and splits a data batch in to multiple subbatches
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tree import InternalParseNode, Tree
from transition_systems import Action, AttachJuxtapose
from collections import defaultdict
from typing import Dict, Iterator, List, Any, Sequence, Tuple, Optional
from utils import get_device


class EpochEnd(Exception):
    """
    Exception raised by Environment.reset() when an epoch ends
    """

    pass


class State:
    """
    Batched parser state
    """

    # a batch of partial trees
    partial_trees: Sequence[Tree]
    # a list where each element is a list of tokens in a sentence
    tokens_word: List[List[str]]
    # batch_size x max_len x d_model
    tokens_emb: torch.Tensor
    # batch_size, the next token's position in the sentence (starting from 0)
    next_token_pos: torch.Tensor
    # the number of actions executed on the current batch
    n_step: int
    # batch_size, the index in the current batch
    batch_idx: List[int]

    def __init__(
        self,
        partial_trees: List[Tree],
        tokens_word: List[List[str]],
        tokens_emb: torch.Tensor,
        next_token_pos: torch.Tensor,
        n_step: int,
        batch_idx: List[int],
    ) -> None:
        assert all(next_token_pos[i] <= len(sent) for i, sent in enumerate(tokens_word))
        assert (
            len(partial_trees)
            == len(tokens_word)
            == tokens_emb.size(0)
            == next_token_pos.size(0)
        )
        self.partial_trees = partial_trees
        self.tokens_word = tokens_word
        self.tokens_emb = tokens_emb
        self.next_token_pos = next_token_pos
        self.n_step = n_step
        self.batch_idx = batch_idx

    @property
    def batch_size(self) -> int:
        "Here 'batch' actually means subbatch"
        return len(self.partial_trees)
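

# A minimal sketch (illustration only, not part of the original API) of how a
# State for two 3-token sentences could be constructed by hand. The sentences,
# the embedding size of 16, and the helper name are assumptions made here.
def _example_state(d_model: int = 16) -> State:
    "Build a toy State with empty partial trees and random token embeddings"
    tokens_word = [["the", "cat", "sat"], ["dogs", "bark", "loudly"]]
    return State(
        partial_trees=[None, None],  # no tokens have been attached yet
        tokens_word=tokens_word,
        tokens_emb=torch.randn(2, 3, d_model),  # batch_size x max_len x d_model
        next_token_pos=torch.zeros(2, dtype=torch.int64),
        n_step=0,
        batch_idx=[0, 1],
    )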


def is_completing(state: State, i: int) -> bool:
    "See if the ith partial tree in the state is about to complete"
    pos = state.next_token_pos[i].item()
    return pos >= len(state.tokens_word[i]) - 1
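

# Illustration only: the completion condition on the toy State sketched above.
# A tree is "completing" once the next token position reaches the last token,
# i.e., len(sentence) - 1 == 2 for a 3-token sentence.
def _demo_is_completing() -> None:
    "Check is_completing against the hypothetical _example_state helper"
    state = _example_state()
    state.next_token_pos[1] = 2  # about to attach the last of 3 tokens
    assert not is_completing(state, 0) and is_completing(state, 1)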


class Environment:
    "Environment for executing actions to update the state"

    # data loader
    loader: Iterator[Any]
    # token encoder
    encoder: nn.Module
    # CPU or GPU
    device: torch.device
    # the maximum number of tokens in a subbatch
    subbatch_max_tokens: int
    # the current parser state
    state: State
    # the entire batch
    cached_data: Dict[str, Any]
    # the current data subbatch
    data_batch: Dict[str, Any]
    # the current subbatch is cached_data[_start_idx : _end_idx] in the entire batch
    _start_idx: int
    _end_idx: int
    # predicted trees in the current subbatch
    pred_trees: List[Optional[InternalParseNode]]
    # ground truth trees in the current subbatch (None if the data has no trees)
    gt_trees: Optional[List[InternalParseNode]]
    # tensors in a batch
    _tensors_to_load = ["tokens_idx", "word_end_mask", "valid_tokens_mask", "tags_idx"]

    def __init__(
        self,
        loader: DataLoader,  # type: ignore
        encoder: torch.nn.Module,
        subbatch_max_tokens: int,
    ) -> None:
        self.loader = iter(loader)
        self.encoder = encoder
        self.device, _ = get_device()
        self.subbatch_max_tokens = subbatch_max_tokens
        self.cached_data = {}
        self.data_batch = {}
        self.pred_trees = []
        self.gt_trees = []

    def _load_data(self) -> None:
        """
        Load a data subbatch into self.data_batch
        The loaded data examples range from self._start_idx (inclusive) to
        self._end_idx (exclusive) in self.cached_data
        """
        need_data = (
            self.cached_data == {} or self._end_idx >= self.cached_data["num_examples"]
        )
        if need_data:  # need to load another batch
            try:
                self.cached_data = next(self.loader)
            except StopIteration:
                raise EpochEnd()
            self.cached_data["num_examples"] = len(self.cached_data["tokens_word"])
            self._end_idx = 0

        self._start_idx = self._end_idx
        self._end_idx += 1

        # increase self._end_idx until reaching subbatch_max_tokens
        lens = self.cached_data["valid_tokens_mask"].detach().sum(dim=-1)
        max_len = lens[self._start_idx].item()
        while self._end_idx < self.cached_data["num_examples"]:
            max_len = max(max_len, lens[self._end_idx].item())
            # including dummy tokens due to padding
            total_num_tokens = max_len * (self._end_idx - self._start_idx + 1)
            if total_num_tokens > self.subbatch_max_tokens:
                max_len = lens[self._start_idx : self._end_idx].max().item()
                break
            self._end_idx += 1

        for k, v in self.cached_data.items():
            # no data left in self.data_batch
            assert k not in self.data_batch or self.data_batch[k] == []
            if k == "num_examples":
                pass
            elif k in Environment._tensors_to_load:
                self.data_batch[k] = v[self._start_idx : self._end_idx, :max_len].to(
                    device=self.device, non_blocking=True
                )
            else:
                self.data_batch[k] = v[self._start_idx : self._end_idx]
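
    # Worked example of the splitting rule above (the numbers are illustrative,
    # not from the original): with valid-token lengths [50, 60, 40] and
    # subbatch_max_tokens = 150, the first subbatch grows to examples [0, 2):
    # padding both to max_len = 60 costs 60 * 2 = 120 <= 150 tokens, whereas
    # also including the third example would cost 60 * 3 = 180 > 150.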

    def reset(self, force: bool = False) -> State:
        """
        Reset a completed state
        1. Load a new data subbatch into self.data_batch
        2. Run the token encoder on self.data_batch
        force: reset even if the current subbatch hasn't completed
        """
        if force:
            self.data_batch = {}
        self._load_data()  # load some data examples
        batch_size = self.data_batch["tokens_idx"].size(0)

        # run the token encoder
        is_train = self.encoder.training
        with torch.set_grad_enabled(is_train):
            tokens_emb = self.encoder(
                self.data_batch["tokens_idx"],
                self.data_batch["tags_idx"],
                self.data_batch["valid_tokens_mask"],
                self.data_batch["word_end_mask"],
            )

        # initialize the state
        self.state = State(
            [None for _ in range(batch_size)],
            self.data_batch["tokens_word"],
            tokens_emb,
            next_token_pos=torch.zeros(
                batch_size, dtype=torch.int64, device=self.device
            ),
            n_step=0,
            batch_idx=list(range(batch_size)),
        )

        self.pred_trees = [None for _ in range(batch_size)]
        self.gt_trees = self.data_batch["trees"] if "trees" in self.data_batch else None
        return self.state

    def step(self, actions: List[Action]) -> Tuple[State, bool]:
        """
        Execute a subbatch of actions to update the state
        """
        batch_size = len(actions)
        assert batch_size == len(self.state.partial_trees)

        done = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        new_partial_trees = []
        new_tokens_word = []
        new_next_token_pos = []
        new_batch_idx = []
        data_batch = defaultdict(list)

        for i, action in enumerate(actions):
            # execute the ith action
            pos = self.state.next_token_pos[i].item()
            assert isinstance(pos, int)
            tag = self.data_batch["tags"][i][pos]
            word = self.data_batch["tokens_word"][i][pos]
            tree = AttachJuxtapose.execute(
                self.state.partial_trees[i],
                action,
                pos,
                tag,
                word,
                immutable=False,
            )
            if is_completing(self.state, i):
                # the tree is completed
                done[i] = True
                assert tree is not None
                batch_idx = self.state.batch_idx[i]
                assert self.pred_trees[batch_idx] is None
                self.pred_trees[batch_idx] = tree
            else:
                # the tree hasn't completed
                new_partial_trees.append(tree)
                new_tokens_word.append(self.state.tokens_word[i])
                for k, v in self.data_batch.items():
                    data_batch[k].append(v[i])
                new_next_token_pos.append(pos + 1)
                new_batch_idx.append(self.state.batch_idx[i])

        self.data_batch = dict(data_batch)
        self.state.partial_trees = new_partial_trees
        self.state.tokens_word = new_tokens_word
        self.state.tokens_emb = self.state.tokens_emb[~done]
        self.state.next_token_pos = self.state.next_token_pos.new_tensor(
            new_next_token_pos
        )
        self.state.n_step += 1
        self.state.batch_idx = new_batch_idx

        all_done = done.all().item()
        assert isinstance(all_done, bool)
        return self.state, all_done

    def gt_actions(self) -> List[Action]:
        """
        Get the ground truth actions at the current step
        """
        return [
            action_seq[self.state.n_step].normalize()
            for action_seq in self.data_batch["action_seq"]
        ]

    def gt_action_seqs(self) -> List[List[Action]]:
        """
        Get all ground truth actions for the current subbatch
        """
        return [
            [action.normalize() for action in action_seq]
            for action_seq in self.data_batch["action_seq"]
        ]
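

# A minimal teacher-forcing rollout sketch (illustration only). It assumes that
# `loader` yields batches in the format consumed above and that each batch
# carries ground-truth "action_seq" entries; the function name and the
# subbatch_max_tokens value are placeholders, not part of the original API.
def _example_rollout(loader: DataLoader, encoder: nn.Module) -> None:
    "Run one epoch, driving the parser with the ground truth actions"
    env = Environment(loader, encoder, subbatch_max_tokens=3000)
    try:
        while True:
            env.reset()  # load the next subbatch and encode its tokens
            all_done = False
            while not all_done:
                _, all_done = env.step(env.gt_actions())
            # env.pred_trees now holds the completed trees for this subbatch
    except EpochEnd:
        pass  # the data loader is exhausted; the epoch has ended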