-
Notifications
You must be signed in to change notification settings - Fork 1
/
multichoice.py
54 lines (48 loc) · 1.43 KB
/
multichoice.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from dartboard import encode, get_knn_crosscoder, get_dartboard_crosscoder2, get_dists_crosscoder
def _print_results(header, titles, texts):
    """Print a blank line, *header*, then one ``title: text`` row per title.

    ``titles`` are keys into ``texts``; they are printed in the given order.
    """
    print()
    print(header)
    for title in titles:
        print(f' {title:2}: {texts[title]}')


def main():
    """Demo: rerank canned replies to one query with two crosscoder strategies.

    Encodes every candidate reply, prints the candidates and the query, then
    prints the top-``k`` titles returned by ``get_knn_crosscoder`` and by
    ``get_dartboard_crosscoder2`` (fed from ``get_dists_crosscoder``).
    """
    query = 'Do you want to watch soccer?'
    candidates = [
        'Absolutely!',
        'Affirmative!',
        'I don\'t know!',
        'I\'d love to!',
        'Maybe later.',
        'Maybe!',
        'Maybe...',
        'No thanks.',
        'No way!',
        'No, I don\'t wanna do dat.',
        'No, thank you!',
        'No, thank you.',
        'Not right now.',
        'Not today.',
        'Perhaps..',
        'Sure!',
        'Yeah!',
        'Yes!',
        'Yes, please can we?',
        'Yes, please!',
        'Yes, please.',
        'Yes, we ought to!',
        'Yes, we should!',
    ]
    # 1-based integer "titles" so the printed listing starts at 1.
    texts = dict(enumerate(candidates, start=1))
    embs = {title: encode(text) for title, text in texts.items()}

    # NOTE(review): meanings below are assumed from the dartboard API names,
    # not visible here — confirm against the dartboard module.
    triage = 100  # presumably how many nearest candidates are kept for reranking
    sigma = .5    # presumably the dartboard scoring width
    k = 3         # number of results requested from each strategy

    _print_results('Candidates:', texts, texts)
    print()
    print('Query:', query)

    results = get_knn_crosscoder(query, embs, encode, texts, k, triage)
    _print_results('KNN crosscoder:', results, texts)

    dists = get_dists_crosscoder(query, embs, encode, texts, triage)
    results = get_dartboard_crosscoder2(dists, sigma, k)
    _print_results('Dartboard crosscoder:', results, texts)


if __name__ == '__main__':
    main()