-
Notifications
You must be signed in to change notification settings - Fork 0
/
entry.py
74 lines (63 loc) · 2.64 KB
/
entry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import importlib
import json
import os
import numpy as np
# Load the run configuration once at import time. The context manager makes
# sure the file handle is closed as soon as the JSON has been parsed instead
# of being left for the garbage collector.
with open("configs/deepLogMF.json") as config_file:
    config = json.load(config_file)
class Runner:
    """Train and evaluate an algorithm chosen by name in the configuration.

    The constructor imports the module ``algorithm.<name>`` (where ``<name>``
    is ``config["algorithm"]["name"]``), instantiates its ``Algorithm`` class
    with the ``config["algorithm"]`` section, and loads the train/eval
    interaction matrices (building and caching them on first use).
    """

    def __init__(self, config):
        """Build the algorithm and load the data matrices.

        Args:
            config: Mapping with an ``"epochs"`` entry and an ``"algorithm"``
                section that provides at least ``"name"``, ``"N"`` (number of
                users) and ``"M"`` (number of items).
        """
        self.config = config
        self.epochs = config["epochs"]
        # The algorithm implementation is resolved dynamically by module name.
        algorithm_module = importlib.import_module(
            "algorithm." + config["algorithm"]["name"])
        self.algorithm = algorithm_module.Algorithm(config["algorithm"])
        self._data_preprocess(
            config["algorithm"]["N"], config["algorithm"]["M"])
        # One evaluation score per finished epoch.
        self.train_trace = []

    def _data_preprocess(self, users_num, items_num):
        """Load the train/eval matrices, creating the cached files if needed.

        When ``np_dataset/dataR.dat`` does not exist yet, the raw ratings file
        ``dataset/ml-1m/ratings.dat`` (``::``-separated, user id first, item
        id second, both 1-based) is converted into two binary
        ``users_num x items_num`` matrices: a random 10% of the interactions
        go to the evaluation matrix, the rest to the training matrix. Both
        are written to ``np_dataset/`` and then loaded into ``self.dataR`` and
        ``self.evalR``.

        Args:
            users_num: Number of rows (users) of the matrices.
            items_num: Number of columns (items) of the matrices.

        Raises:
            ValueError: If the raw ratings file contains no interactions.
        """
        os.makedirs("np_dataset", exist_ok=True)

        # Only convert the raw ratings when no cached matrices exist yet.
        if not os.path.exists("np_dataset/dataR.dat"):
            interactions = []
            # Open read-only: write mode would truncate the dataset and make
            # every read fail.
            with open("dataset/ml-1m/ratings.dat", "r") as ratings_file:
                for line in ratings_file:
                    line = line.strip()
                    if not line:
                        continue
                    fields = line.split("::")
                    interactions.append((int(fields[0]), int(fields[1])))
            if not interactions:
                raise ValueError(
                    "dataset/ml-1m/ratings.dat contains no ratings")

            dataset = np.array(interactions, dtype=np.int32)
            users, items = dataset[:, 0], dataset[:, 1]

            # Hold out a random 10% of the interactions for evaluation.
            eval_idx = np.random.choice(
                len(users),
                int(round(0.1 * len(users))),
                replace=False,
            )
            # A boolean mask gives O(1) membership per interaction instead of
            # scanning ``eval_idx`` for every row.
            is_eval = np.zeros(len(users), dtype=bool)
            is_eval[eval_idx] = True

            dataR = np.zeros([users_num, items_num])
            evalR = np.zeros([users_num, items_num])
            # Ids in the file are 1-based; matrix indices are 0-based.
            evalR[users[is_eval] - 1, items[is_eval] - 1] = 1
            dataR[users[~is_eval] - 1, items[~is_eval] - 1] = 1

            np.savetxt("np_dataset/dataR.dat", dataR)
            np.savetxt("np_dataset/evalR.dat", evalR)

        # Load the pre-processed data (as matrices).
        self.dataR = np.loadtxt("np_dataset/dataR.dat")
        self.evalR = np.loadtxt("np_dataset/evalR.dat")

    def train(self):
        """Run ``self.epochs`` epochs of training followed by evaluation.

        After each epoch the value returned by ``self.algorithm.eval`` is
        appended to ``self.train_trace`` and printed under the label ``MPR``.
        """
        for step in range(self.epochs):
            self.algorithm.train(self.dataR)
            mpr = self.algorithm.eval(self.evalR)
            self.train_trace.append(mpr)
            print("MPR@epoch{}: {:.8}".format(str(step+1).zfill(3), mpr))
            # NOTE(review): the original indentation was lost; checkpointing
            # after every epoch is assumed here (the final on-disk state is
            # the same as saving once at the end) - confirm.
            self.save()

    def save(self):
        """Persist the algorithm state and the evaluation trace.

        Creates ``checkpoint/<algorithm name>/`` when missing, asks the
        algorithm to save itself into that directory and writes the collected
        trace to ``trace.txt`` inside it.
        """
        checkpoint_dir = "checkpoint/" + self.config["algorithm"]["name"]
        # Creates the parent "checkpoint" directory as well when needed.
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.algorithm.save(checkpoint_dir)
        np.savetxt(
            os.path.join(checkpoint_dir, "trace.txt"),
            np.array(self.train_trace, dtype=np.float32)
        )
if __name__ == "__main__":
    # Script entry point: build a runner from the module-level configuration
    # and start training right away.
    Runner(config).train()