"""Flask API for serving predictions from a trained PyTorch text-classification model."""
import argparse
import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
import data_loader.data_loaders as module_data
import embedding.embedding as module_embedding
import model.model as module_arch
from parse_config import ConfigParser
from utils.util import predict_class_from_text
app = Flask(__name__)
CORS(app)  # allow cross-origin requests so a browser front-end can call the API

# Command-line interface: which checkpoint to serve and which GPUs to use.
parser = argparse.ArgumentParser(
    description="PyTorch Natural Language Processing Template"
)
parser.add_argument(
    "-r",
    "--resume",
    default=os.path.join(
        "saved", "models", "Email-Spam", "0407_192255", "model_best.pth"
    ),
    type=str,
    # fixed: previous help text claimed "(default: None)" although a concrete
    # checkpoint path default is set above
    help="path to checkpoint to serve (default: bundled Email-Spam best model)",
)
parser.add_argument(
    "-d",
    "--device",
    default=None,
    type=str,
    help="indices of GPUs to enable (default: all)",
)
# NOTE(review): ConfigParser is constructed BEFORE parse_args(); presumably it
# inspects/parses the parser itself -- confirm against parse_config.py before
# reordering these two lines.
config = ConfigParser(parser)
args = parser.parse_args()
# TODO: LOAD VOCAB WITHOUT DATA LOADER
# setup data_loader instance (test split only; its dataset/vocab is reused at
# inference time by the /predict route)
data_loader = getattr(module_data, config["test_data_loader"]["type"])(
    config["test_data_loader"]["args"]["data_dir"],
    batch_size=32,
    seq_length=128,
    shuffle=False,
    validation_split=0.0,
    training=False,
    num_workers=2,
)

# Select the inference device up front so the checkpoint can be remapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# build model architecture
try:
    # The embedding section is optional: some architectures embed internally
    # and the config may omit "embedding" entirely, so failure falls back to
    # None on purpose (best-effort).
    config["embedding"]["args"].update({"vocab": data_loader.dataset.vocab})
    embedding = config.initialize("embedding", module_embedding)
except Exception:  # was a bare except; Exception keeps the best-effort fallback
    # without swallowing SystemExit/KeyboardInterrupt
    embedding = None
config["arch"]["args"].update({"vocab": data_loader.dataset.vocab})
config["arch"]["args"].update({"embedding": embedding})
model = config.initialize("arch", module_arch)

# Load trained weights; map_location lets a GPU-trained checkpoint load on a
# CPU-only host instead of crashing.
checkpoint = torch.load(args.resume, map_location=device)
state_dict = checkpoint["state_dict"]
if config["n_gpu"] > 1:
    model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)

# prepare model for testing
model = model.to(device)
model.eval()
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the posted text.

    Expects a JSON body ``{"in_text": "..."}``; returns JSON
    ``{"class": <int label>, "score": <float confidence>}``.
    Responds 400 on a missing/malformed body instead of a 500.
    """
    payload = request.get_json(silent=True)
    if not payload or "in_text" not in payload:
        # reject malformed requests explicitly rather than crashing on KeyError
        return jsonify({"error": "missing 'in_text' field"}), 400
    score, prediction = predict_class_from_text(
        model=model, input_text=payload["in_text"], dataset=data_loader.dataset
    )
    # .item() converts 0-d tensors to plain Python numbers for JSON encoding
    return jsonify({"class": prediction.item(), "score": score.item()})
if __name__ == "__main__":
    # Listen on all interfaces so the API is reachable from outside the host.
    app.run(host="0.0.0.0", port=8080)