-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
136 lines (102 loc) · 2.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from random import choice
from flask import Flask, jsonify, request
from uuid import uuid4
import json
from sus import SusClient
from math import exp
from collections import deque
import base64
# Load deployment settings from config.json (path relative to the current
# working directory). The keys read here are 'project' and 'endpoint_id',
# which configure the prediction client. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open('config.json', 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)

sus_client = SusClient(project=config['project'],
                       endpoint_id=config['endpoint_id'])
app = Flask(__name__, static_folder='static')
# Every prompt a round can draw from. Each entry pairs the category name
# used as the prompt id ('prompt') with the human-readable text recorded
# alongside submissions ('description').
prompt_options = tuple(
    {'prompt': prompt_id, 'description': label}
    for prompt_id, label in (
        ('duck', 'duck'),
        ('apple', 'apple'),
        ('amongUs', 'among us'),
        ('minion', 'minion'),
        ('piano', 'piano'),
        ('shark', 'shark'),
        ('tree', 'tree'),
        ('house', 'house'),
    )
)
def get_prompt() -> dict:
    """Return a randomly chosen entry of ``prompt_options``.

    The result is one of the ``{'prompt': ..., 'description': ...}`` dicts
    itself (not a copy and not a string), so callers must not mutate it.
    """
    return choice(prompt_options)
def guid() -> str:
    """Return a new random (version 4) UUID as 32 lowercase hex digits, no dashes."""
    return str(uuid4()).replace('-', '')
@app.route('/')
def index():
    """Serve the single-player page, index.html, from the static folder."""
    page = "index.html"
    return app.send_static_file(page)
@app.route('/multiplayer/')
def multiplayer():
    """Serve the multiplayer page, index_multiplayer.html, from the static folder."""
    page = "index_multiplayer.html"
    return app.send_static_file(page)
@app.route('/api/prompts/new', methods=["POST"])
def create_round_get_prompt():
    """Start a round: respond with a randomly picked prompt entry as JSON."""
    picked = get_prompt()
    return jsonify(picked), 200
def conf_to_score(confidence: float) -> int:
    """Map a model confidence in [0, 1] to an integer score.

    Applies the logistic curve ``101 / (100 * e**(-9.2 * confidence) + 1)``
    and truncates toward zero. This yields 1 at confidence 0, 50 at 0.5 and
    99 at 1.

    Raises:
        ValueError: if ``confidence`` lies outside [0, 1] (NaN included,
            since every comparison with NaN is false).
    """
    # Explicit check instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let out-of-range values be scored.
    if not 0 <= confidence <= 1:
        raise ValueError("confidence domain")
    return int(101 / (100 * exp(-9.2 * confidence) + 1))
# Most recent scored submissions, oldest first, served by
# /api/prompts/recent. With maxlen set, append() discards the oldest entry
# once six are held, so the buffer can never grow past that size.
recent_submissions = deque(maxlen=6)
@app.route('/api/prompts/recent')
def get_recent_subs():
    """Respond with the stored recent submissions, oldest first, as a JSON list."""
    snapshot = [*recent_submissions]
    return jsonify(snapshot), 200
@app.route('/api/prompts/<prompt_name>/submit', methods=["POST"])
def make_submission_for_prompt(prompt_name: str):
    """Score a PNG drawing submitted for ``prompt_name`` and record it.

    The raw request body is passed to ``sus_client.predict``. Its result is
    used as a mapping from category name to confidence. The score is
    ``conf_to_score`` of the confidence for ``prompt_name``, or 0 when that
    category is absent from the result. Each scored submission (with its
    base64-encoded image) is appended to ``recent_submissions``.

    Returns:
        200 with ``{'prompt_id', 'score', 'categories'}``, where
        ``categories`` holds the ``prompt_options`` entries for every
        predicted category that is a known prompt; or 400 with a ``msg``
        when the body is not non-empty ``image/png`` content.
    """
    # ``mimetype`` is the content type without parameters and lowercased,
    # so e.g. "image/png; charset=binary" is still accepted as PNG.
    if request.mimetype != 'image/png':
        return jsonify({'msg': "expected image/png content"}), 400
    data = request.get_data()
    if not data:
        # Nothing to classify; reject before calling the prediction service.
        return jsonify({'msg': "expected non-empty image/png content"}), 400

    result_dict = sus_client.predict(data)

    # Index prompts by name once, instead of rescanning prompt_options for
    # every predicted category and again for the description.
    prompts_by_name = {p['prompt']: p for p in prompt_options}

    score = 0
    if prompt_name in result_dict:
        score = conf_to_score(result_dict[prompt_name])

    categories = [prompts_by_name[name]
                  for name in result_dict
                  if name in prompts_by_name]

    response = {
        'prompt_id': prompt_name,
        'score': score,
        'categories': categories,
    }

    known_prompt = prompts_by_name.get(prompt_name)
    desc = known_prompt['description'] if known_prompt is not None else ''

    recent_submissions.append({
        'description': desc,
        'score': score,
        'img': base64.b64encode(data).decode('utf-8'),
    })
    # Keep only the six most recent submissions.
    while len(recent_submissions) > 6:
        recent_submissions.popleft()

    return jsonify(response), 200
if __name__ == '__main__':
    # Executed directly (not imported): start Flask's built-in server.
    app.run()