lrp_new.py
"""Run layer-wise relevance propagation (LRP) over per-class binary models.

Loads one model from a pickled model dictionary, builds a test dataloader for
the matching class, computes input relevance scores with InnvestigateModel,
and pickles the resulting evidence.
"""
import io
import os
import pickle
from argparse import ArgumentParser

import torch
import torch.utils.data

from innvestigator import InnvestigateModel

class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps CUDA tensor storages to the CPU, so models
    pickled on a GPU machine can be loaded on a CPU-only machine."""
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)

def load_model_by_index(models_path, index):
    """Load the model stored under `index` in the pickled models dict,
    moving it to the GPU when one is available."""
    with open(models_path, 'rb') as f:
        models = CPU_Unpickler(f).load()
    model = models.get(index)
    if model is not None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
    return model

def get_test_dataloader(test_data_path, target_class):
    # TODO: confirm how the test data is stored on disk; this assumes a
    # pickled (words, vectors) pair where `vectors` maps class -> vectors
    with open(test_data_path, 'rb') as f:
        words, vectors = pickle.load(f)
    X = torch.tensor([])
    targets = torch.tensor([])
    for cls in vectors.keys():
        X = torch.cat((X, torch.tensor(vectors[cls]).float()))
        # binary labels: 1 for the target class, 0 for every other class
        label = 1.0 if cls == target_class else 0.0
        targets = torch.cat((targets, torch.full((len(vectors[cls]),), label)))
    # wrap the tensors in a shuffled dataloader
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X, targets), batch_size=64, shuffle=True)
    return loader
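
# Assumed (unverified) layout of the pickled test data, for illustration only;
# the TODO above still stands and the real format must be confirmed:
#   words:   list of word/token labels
#   vectors: dict mapping class index -> array of shape [n_samples, n_features]
# e.g. pickle.dump((words, vectors), open('test_vectors.pkl', 'wb'))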

# reference: https://github.com/moboehle/Pytorch-LRP/blob/master/MNIST%20example.ipynb
def do_lrp(model, dataloader, device):
    inn_model = InnvestigateModel(model, lrp_exponent=2,
                                  method="e-rule",
                                  beta=.5)
    # each model is a binary classifier for a single class, so all relevance
    # scores collected here are evidence for that one class
    evidence_for_class = []
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        # The reference notebook calls innvestigate once per class
        # (rel_for_class=i) because of a PyTorch memory-freeing issue. Since
        # our models are class-specific, a single call per batch suffices;
        # with rel_for_class=None the 'winning' (argmax) class is used.
        model_pred, input_rel_values = inn_model.innvestigate(in_tensor=data)
        evidence_for_class.append(input_rel_values)
    # the reference notebook visualizes relevance maps on MNIST images; an
    # equivalent visualization for this feature-vector data is still TODO
    return evidence_for_class
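
# A minimal sketch (assuming each relevance tensor has shape
# [batch, n_features]) of how the per-batch evidence returned by do_lrp
# could be reduced to one mean-relevance score per input feature:
#   flat = torch.cat([rel.detach().cpu() for rel in evidence], dim=0)
#   mean_relevance = flat.mean(dim=0)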

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--models_path", type=str, required=True, help="Path to the serialized file containing models.")
    parser.add_argument("--index", type=int, required=True, help="Index of the model to load.")
    parser.add_argument("--test_data_path", type=str, required=True, help="Filepath for the test dataset.")
    args = parser.parse_args()
    # if the combined models file does not exist yet, rebuild it by merging
    # the per-interval shard files listed in arch1_parallel_training_sharding.pkl
    if not os.path.exists(args.models_path):
        with open('arch1_parallel_training_sharding.pkl', 'rb') as f:
            interval_dict = pickle.load(f)
        models = {}
        interval = 25
        for start in interval_dict.keys():
            end = start + interval
            with open(f'models_{start}_{end}.pkl', 'rb') as f:
                # on CPU-only machines, remap CUDA storages while unpickling
                if not torch.cuda.is_available():
                    shard = CPU_Unpickler(f).load()
                else:
                    shard = pickle.load(f)
            for key in shard.keys():
                if shard[key] is not None:
                    models[key] = shard[key].to('cpu')
        # save the merged models dictionary to a single file
        with open(args.models_path, 'wb') as f:
            pickle.dump(models, f)
    # load the model using the provided path and index
    model = load_model_by_index(args.models_path, args.index)
    if model is not None:
        print("Model loaded successfully.")
    else:
        raise SystemExit(f"No model found at index {args.index}")
    # the model at this index is the binary classifier for class `index`,
    # so the same index selects the target class for the dataloader
    dataloader = get_test_dataloader(args.test_data_path, args.index)
    print("Dataloader created successfully.")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
evidence = do_lrp(model, dataloader, device)
evidence_path = f'evidence_{args.index}.pkl'
with open(evidence_path, 'wb') as f:
pickle.dump(evidence, f)
print(f"Evidence saved to {evidence_path}.")