-
Notifications
You must be signed in to change notification settings - Fork 1
/
masking.py
executable file
·171 lines (147 loc) · 5.97 KB
/
masking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from math import sqrt
from logger import logger
import torch
import math
def determine_masking_sequence(args, n_heads=0, n_layers=0):
    """Convert cumulative masking targets into per-step increments.

    The cumulative targets come from ``args.mask_number`` when it is not
    ``None``; otherwise they are derived from the percentages in
    ``args.mask_percent`` applied to ``n_heads * n_layers``.

    Args:
        args: Object exposing ``mask_number``, ``mask_percent`` and
            ``min_number_attention_heads`` attributes.
        n_heads: Number of attention heads per layer.
        n_layers: Number of layers.

    Returns:
        A list whose i-th entry is the number of *additional* heads to mask
        at step i, so that the running sum reproduces the sorted cumulative
        targets. An empty list is returned when there are no targets.
    """
    mask_number = args.mask_number
    if mask_number is None:
        # Compute the number of heads to mask from percentages.
        total_heads = n_heads * n_layers
        min_heads = args.min_number_attention_heads
        # Largest count that still leaves `min_heads` heads in every layer;
        # clamped so an over-large minimum cannot produce a negative target.
        max_maskable = max(total_heads - min_heads * n_layers, 0)
        mask_number = []
        # Visit percentages in ascending order: once the cap is reached every
        # remaining percentage would be capped too, so stopping there only
        # drops redundant steps (never a smaller, still-valid one).
        for prune_percent in sorted(args.mask_percent):
            n_to_mask = int(total_heads * prune_percent / 100)
            if min_heads > 0 and n_to_mask > max_maskable:
                mask_number.append(max_maskable)
                break
            mask_number.append(n_to_mask)

    # Masking is applied incrementally, so work on ascending cumulative counts.
    mask_number = sorted(mask_number)
    if not mask_number:
        return []

    # First step masks the first target; every later step masks the
    # difference to the previous target. The increments telescope, so their
    # sum equals the largest cumulative target by construction.
    mask_sequence = mask_number[:1] + [
        current - previous
        for previous, current in zip(mask_number, mask_number[1:])
    ]
    return mask_sequence
def what_to_threshold(
    head_importance,
    mask_threshold,
    heads_to_mask=None,
    min_number_attention_heads=0,
    n_heads=0,
    n_layers=0,
):
    """Mark every head whose importance score is below a threshold.

    Args:
        head_importance: Mapping from a key to a score object. Each score
            must support ``score < mask_threshold`` followed by
            ``.nonzero()`` yielding (layer, head) index pairs, e.g. a 2-D
            ``torch.Tensor`` (assumed layout: layers x heads).
        mask_threshold: Heads with a score strictly below this are masked.
        heads_to_mask: Optional existing ``{key: {layer: {head, ...}}}``
            mapping. It is updated in place and returned; a fresh dict is
            created when ``None``.
        min_number_attention_heads: Not used by this function.
        n_heads: Not used by this function.
        n_layers: Not used by this function.

    Returns:
        The ``heads_to_mask`` mapping with the newly thresholded heads added.
    """
    if heads_to_mask is None:
        heads_to_mask = {}

    for idx, score in head_importance.items():
        masked = heads_to_mask.setdefault(idx, {})
        # Rows of the result are [layer, head] indices of sub-threshold scores.
        below_threshold = (score < mask_threshold).nonzero().tolist()
        for layer, head in below_threshold:
            # Sets make re-adding an already masked head a no-op.
            masked.setdefault(layer, set()).add(head)
    return heads_to_mask
def what_to_mask(
    head_importance,
    n_to_mask,
    heads_to_mask=None,
    min_number_attention_heads=0,
    n_heads=0,
    n_layers=0,
    reverse=False
):
    """Select up to ``n_to_mask`` additional heads to mask per key.

    Heads are ranked by score (ascending by default, descending when
    ``reverse`` is true) and the first ``n_to_mask`` not-yet-masked heads in
    that ranking are added to ``heads_to_mask``.

    Args:
        head_importance: Mapping from a key to a score object indexable as
            ``score[layer, head]`` for ``layer < n_layers`` and
            ``head < n_heads``.
        n_to_mask: Number of new heads to mask for each key.
        heads_to_mask: Optional existing ``{key: {layer: {head, ...}}}``
            mapping. It is updated in place and returned; a fresh dict is
            created when ``None``.
        min_number_attention_heads: When non-zero, the last
            ``min_number_attention_heads`` heads of each layer in the ranking
            are never selected.
        n_heads: Number of heads per layer to consider.
        n_layers: Number of layers to consider.
        reverse: Rank heads from highest to lowest score instead.

    Returns:
        The ``heads_to_mask`` mapping with the newly selected heads added.
    """
    if heads_to_mask is None:
        heads_to_mask = {}

    for idx, score in head_importance.items():
        masked = heads_to_mask.setdefault(idx, {})

        # Rank all (layer, head) pairs by score.
        heads_and_score = [
            ((layer, head), score[layer, head])
            for layer in range(n_layers)
            for head in range(n_heads)
        ]
        heads_and_score.sort(key=lambda x: x[1], reverse=reverse)
        sorted_heads = [pair for pair, _ in heads_and_score]

        # Ensure we don't mask all heads in a layer.
        if min_number_attention_heads:
            # Walk from the end of the ranking (the heads that would be
            # masked last) and hold back the first few seen in each layer.
            protected = {}
            kept = []
            for layer, head in reversed(sorted_heads):
                if protected.get(layer, 0) < min_number_attention_heads:
                    protected[layer] = protected.get(layer, 0) + 1
                    continue
                kept.append((layer, head))
            # Appending then reversing once restores ranking order without
            # repeated front insertions.
            kept.reverse()
            sorted_heads = kept

        # Skip layer/heads that were already masked.
        candidates = [
            (layer, head)
            for (layer, head) in sorted_heads
            if head not in masked.get(layer, ())
        ]

        # Update heads to mask.
        for layer, head in candidates[:n_to_mask]:
            masked.setdefault(layer, set()).add(head)
    return heads_to_mask
def what_to_mask_iterative(
    head_importance,
    n_to_mask,
    curr_n_to_mask,
    heads_to_mask=None,
    min_number_attention_heads=0,
    n_heads=0,
    n_layers=0
):
    """Select up to ``curr_n_to_mask`` additional lowest-scoring heads per key.

    Heads are ranked by ascending score and the first ``curr_n_to_mask``
    not-yet-masked heads in that ranking are added to ``heads_to_mask``.

    Args:
        head_importance: Mapping from a key to a score object indexable as
            ``score[layer, head]`` for ``layer < n_layers`` and
            ``head < n_heads``.
        n_to_mask: Not used by this function.
        curr_n_to_mask: Number of new heads to mask for each key in this
            call.
        heads_to_mask: Optional existing ``{key: {layer: {head, ...}}}``
            mapping. It is updated in place and returned; a fresh dict is
            created when ``None``.
        min_number_attention_heads: When non-zero, the
            ``min_number_attention_heads`` highest-scoring heads of each
            layer are never selected.
        n_heads: Number of heads per layer to consider.
        n_layers: Number of layers to consider.

    Returns:
        The ``heads_to_mask`` mapping with the newly selected heads added.
    """
    if heads_to_mask is None:
        heads_to_mask = {}

    for idx, score in head_importance.items():
        masked = heads_to_mask.setdefault(idx, {})

        # Rank all (layer, head) pairs from lowest to highest score.
        heads_and_score = [
            ((layer, head), score[layer, head])
            for layer in range(n_layers)
            for head in range(n_heads)
        ]
        heads_and_score.sort(key=lambda x: x[1])
        sorted_heads = [pair for pair, _ in heads_and_score]

        # Ensure we don't mask all heads in a layer.
        if min_number_attention_heads:
            # Walk from the highest score down and hold back the first few
            # heads seen in each layer.
            protected = {}
            kept = []
            for layer, head in reversed(sorted_heads):
                if protected.get(layer, 0) < min_number_attention_heads:
                    protected[layer] = protected.get(layer, 0) + 1
                    continue
                kept.append((layer, head))
            # Appending then reversing once restores ascending order without
            # repeated front insertions.
            kept.reverse()
            sorted_heads = kept

        # Skip layer/heads that were already masked.
        candidates = [
            (layer, head)
            for (layer, head) in sorted_heads
            if head not in masked.get(layer, ())
        ]

        # Update heads to mask.
        for layer, head in candidates[:curr_n_to_mask]:
            masked.setdefault(layer, set()).add(head)
    return heads_to_mask