-
Notifications
You must be signed in to change notification settings - Fork 1
/
benchmark_sgc.py
45 lines (40 loc) · 1.85 KB
/
benchmark_sgc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import pandas as pd
import unsupervised_benchmark
import semi_supervised_benchmark
import os
import tqdm
# Graph files to benchmark are read from ~/benchmark_graphinference/graph.
home = os.path.expanduser("~")
graph_path_default = os.path.join(
    home,
    "benchmark_graphinference",
    "graph",
)
def _parse_graph_filename(file):
    """Split a graph file name into its settings, or return None if it is too short.

    The name (with any ".gz" removed) is split on "_" and the first eight
    fields are read, in this order: dataset, graph_type, minmaxscaler, nn,
    normalization, nnk, self_loop, kalofolias. All values stay strings.
    """
    fields = file.replace(".gz", "").split("_")
    if len(fields) < 8:
        return None
    (dataset, graph_type, minmaxscaler, nn,
     normalization, nnk, self_loop, kalofolias) = fields[:8]
    # Key order fixes the column order of the resulting CSV.
    return dict(
        dataset=dataset,
        graph_type=graph_type,
        minmaxscaler=minmaxscaler,
        nn=nn,
        normalization=normalization,
        nnk=nnk,
        kalofolias=kalofolias,
        self_loop=self_loop,
    )


def benchmark_all(graph_path=None, output_path="results/sgc1000.csv", runs=1000, split=20):
    """Run the semi-supervised benchmark on every graph file and save one CSV.

    Every regular file in ``graph_path`` whose name does not contain
    "toronto" is processed in reverse-sorted name order. The settings encoded
    in the file name are combined with the four values returned by
    ``semi_supervised_benchmark.run_semi_supervised_benchmark`` into one row
    per graph.

    Args:
        graph_path: Directory containing the graph files. ``None`` (default)
            uses the module-level ``graph_path_default``.
        output_path: Destination CSV file. Its parent directory is created
            if it does not exist.
        runs: Forwarded as ``runs`` to the benchmark function.
        split: Forwarded as ``split`` to the benchmark function.

    Returns:
        None. The results are written to ``output_path``.

    Notes:
        * Files whose name has fewer than eight "_"-separated fields are
          printed and skipped.
        * If the benchmark raises for a graph, the error is printed and the
          row keeps 0 for all four accuracy columns, so a row of zeros means
          "failed", not "measured zero accuracy".
        * Only ``Exception`` is caught, so Ctrl-C / ``SystemExit`` stop the
          run instead of being swallowed.
    """
    if graph_path is None:
        graph_path = graph_path_default

    all_graphs = [
        f
        for f in os.listdir(graph_path)
        if "toronto" not in f and os.path.isfile(os.path.join(graph_path, f))
    ]

    all_dicts = []
    for file in tqdm.tqdm(sorted(all_graphs, reverse=True), total=len(all_graphs)):
        graph_dict = _parse_graph_filename(file)
        if graph_dict is None:
            # Name does not follow the expected layout: report it and move on.
            print(file)
            continue

        # Fallback values recorded when the benchmark fails for this graph.
        acc_train, acc_train_std, acc_test, acc_test_std = 0, 0, 0, 0
        try:
            acc_train, acc_train_std, acc_test, acc_test_std = (
                semi_supervised_benchmark.run_semi_supervised_benchmark(
                    dataset=graph_dict["dataset"],
                    graph_path=os.path.join(graph_path, file),
                    minmaxscaler=graph_dict["minmaxscaler"],
                    runs=runs,
                    split=split,
                )
            )
        except Exception as err:
            # Best effort: one failing graph must not abort the whole sweep,
            # but the cause is reported so it can be investigated.
            print("Error semi supervised {}: {!r}".format(file, err))

        graph_dict["acc_train"] = acc_train
        graph_dict["acc_train_std"] = acc_train_std
        graph_dict["acc_test"] = acc_test
        graph_dict["acc_test_std"] = acc_test_std
        all_dicts.append(graph_dict)

    df = pd.DataFrame(all_dicts)
    # Create the output directory first so a missing folder cannot discard
    # the results of the whole run at the very last step.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_path, index=False)
# Script entry point: run the full benchmark sweep only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    benchmark_all()