-
Notifications
You must be signed in to change notification settings - Fork 7
/
classical_benchmark.py
74 lines (60 loc) · 1.97 KB
/
classical_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
##########################################################################
#Quantum classifier
#Adrián Pérez-Salinas, Alba Cervera-Lierta, Elies Gil, J. Ignacio Latorre
#Code by APS
#Code-checks by ACL
#June 11th 2019
#Universitat de Barcelona / Barcelona Supercomputing Center/Institut de Ciències del Cosmos
###########################################################################
#This file provides a classical benchmark for the same problem our quantum classifier is tackling
#We use two standard scikit-learn classifiers: a neural network (MLPClassifier) and a support vector classifier (SVC)
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from data_gen import data_generator
import numpy as np
# Pick the benchmark problem and generate its data set.
problem = 'squares'
print(problem)
data, drawing = data_generator(problem)  # the argument is the problem name

# Training-set size per problem; any problem not listed here uses 200 samples.
number = {'sphere': 500, 'hypersphere': 1000}.get(problem, 200)

# The first `number` samples are used for training, the remainder for testing.
train_data, test_data = data[:number], data[number:]
# Every sample is an (x, y) pair: x goes to the classifier inputs (X_*),
# y to the classifier targets (Y_*).
X_train = [sample_x for sample_x, _ in train_data]
Y_train = [sample_y for _, sample_y in train_data]

X_test = [sample_x for sample_x, _ in test_data]
Y_test = [sample_y for _, sample_y in test_data]
def _accuracy(predicted, expected):
    """Return the fraction of predictions that exactly match the expected labels."""
    # Exact-match accuracy. For 0/1 labels this equals the previous
    # 1 - sum(|pred - y|) / len(y); unlike that formula it remains a valid
    # accuracy when labels take more than two values, because a wrong
    # prediction always counts as exactly one error.
    return float(np.mean(np.asarray(predicted) == np.asarray(expected)))


# Best test accuracy seen so far for each classifier.
nn = 0
svc = 0

# One accuracy per line is written to each file, one line per repetition.
# The context manager guarantees both files are closed even if training,
# prediction or writing raises.
with open('classical_benchmark/' + problem + '_nn', mode='w') as text_file_nn, \
        open('classical_benchmark/' + problem + '_svc', mode='w') as text_file_svc:
    for i in range(10):
        # Neural network: one hidden layer of 100 ReLU units, trained with L-BFGS.
        clf = MLPClassifier(hidden_layer_sizes=(100,), activation='relu',
                            solver='lbfgs')
        clf.fit(X_train, Y_train)
        pred = clf.predict(X_test)
        value_nn = _accuracy(pred, Y_test)
        text_file_nn.write(str(value_nn) + '\n')
        if value_nn > nn:
            nn = value_nn

        # Support vector classifier.
        # NOTE(review): the SVC is refit on identical data in every
        # repetition; confirm that repeating it is intended.
        clf = SVC(gamma='auto')
        clf.fit(X_train, Y_train)
        pred = clf.predict(X_test)
        value_svc = _accuracy(pred, Y_test)
        text_file_svc.write(str(value_svc) + '\n')
        if value_svc > svc:
            svc = value_svc

# Report the best accuracy reached by each classifier over the repetitions.
print('NN: ', nn)
print('SVC: ', svc)