-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_lund.py
executable file
·114 lines (98 loc) · 3.94 KB
/
test_lund.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
#
# Frederic Dreyer, BOOST 2018 tutorial
#
# Load a keras model and apply it to a sample of W and dijet images,
# then plot the results.
#
# Usage:
# python3 test_lund.py [--sig file_sig] [--bkg file_bkg]
# [--threshold threshold] [--nev nevents]
# [--model file_model]
#
import keras
import read_lund_json as lund
from matplotlib.colors import LogNorm
import numpy as np
import matplotlib.pyplot as plt
import argparse
# side length, in pixels, of the square lund images; the image array built
# below is npxl x npxl, and the same value is handed to the reader
# (presumably as its pixel count -- confirm against read_lund_json)
npxl = 50


def _average_image(imgs):
    """Return the transposed pixel-wise average of a stack of images.

    imgs is indexed as (image, row, column).  An empty stack (e.g. when
    the threshold selects no jet at all) gives a blank image instead of
    a NaN-filled one accompanied by a numpy RuntimeWarning.
    """
    if len(imgs) == 0:
        return np.transpose(np.zeros(imgs.shape[1:]))
    return np.transpose(np.average(imgs, axis=0))


def _plot_lund_image(fig, position, image, title):
    """Draw image as panel number position (1-4) of the 2x2 grid in fig.

    All panels share the same axis labels, extent, colour map and colour
    range so that they can be compared directly.
    """
    fig.add_subplot(2, 2, position)
    plt.title(title)
    # raw strings: the LaTeX backslashes are not valid python escapes
    plt.xlabel(r'$\ln(R / \Delta)$')
    plt.ylabel(r'$\ln(k_t / \mathrm{GeV})$')
    plt.imshow(image, origin='lower',
               aspect='auto', extent=[0, 7, -3, 7],
               cmap=plt.get_cmap('BuPu'),
               vmax=0.025, vmin=0.0)
    plt.colorbar()


parser = argparse.ArgumentParser(description='Plot lund images')
parser.add_argument('--sig', action='store',
                    default='W-lund-pt2000-parton.json.gz',
                    dest='file_sig',
                    help='file with the signal (W) lund images')
parser.add_argument('--bkg', action='store',
                    default='dijet-lund-pt2000-parton.json.gz',
                    dest='file_bkg',
                    help='file with the background (dijet) lund images')
parser.add_argument('--model', action='store',
                    default='W_Conv_Net_lund_pt2000-parton.h5',
                    dest='file_model',
                    help='file with the saved keras model')
parser.add_argument('--nev', type=int, default=1000, dest='nev')
# '--threshold' is the spelling given in the usage header; the historical
# misspelling '--treshold' is kept as an alias so old command lines still work
parser.add_argument('--threshold', '--treshold', type=float, default=0.5,
                    dest='tresh',
                    help='jets with a signal score above this value are '
                         'tagged as signal')
args = parser.parse_args()

# set up the readers
sig_reader = lund.LundImage(args.file_sig, args.nev, npxl)
bkg_reader = lund.LundImage(args.file_bkg, args.nev, npxl)

# get array from file
print('Reading images from file')
sig_images = np.array(sig_reader.values())
bkg_images = np.array(bkg_reader.values())

# create an image array which contains both the background and signal
# images, which we will then label with the CNN model
# (background images come first, signal images after them)
n_bkg = len(bkg_images)
images = np.zeros((len(sig_images) + n_bkg, 1, npxl, npxl))
images[:n_bkg, 0, :, :] = bkg_images
images[n_bkg:, 0, :, :] = sig_images

# create labels: one-hot rows, [1, 0] for background and [0, 1] for signal.
# They are not used further in this script, but they spell out the class
# convention that the [:, 1] column selection below relies on.
labels = np.concatenate((np.zeros(n_bkg), np.ones(len(sig_images))))
labels = np.asarray([[1 if x == n else 0 for n in range(2)] for x in labels])

# load the keras model and evaluate on images
print('Loading keras model')
model = keras.models.load_model(args.file_model)
print('Evaluating model on data sample (this step might take some time)')
# column 1 is taken as the signal score, following the label convention
# above (assumes the model was trained with that convention)
sig_prob = model.predict(images, verbose=0)[:, 1]

# a jet is tagged as signal when its signal score exceeds the threshold;
# everything else is tagged as background, so each jet is counted once
is_sig_tagged = sig_prob > args.tresh

# prepare plotting of results
print('Plotting results')
fig = plt.figure(figsize=(14, 11))

# plot the signal lund images
sig_tag = images[is_sig_tagged]
print('Tagged %i out of %i as signal jets' % (len(sig_tag), len(images)))
_plot_lund_image(fig, 1, _average_image(sig_images),
                 'Lund image (W) - truth')
_plot_lund_image(fig, 2, _average_image(sig_tag[:, 0, :, :]),
                 'Lund image (W) - tagged')

# plot the background lund images
bkg_tag = images[~is_sig_tagged]
print('Tagged %i out of %i as background jets' % (len(bkg_tag), len(images)))
_plot_lund_image(fig, 3, _average_image(bkg_images),
                 'Lund image (QCD) - truth')
_plot_lund_image(fig, 4, _average_image(bkg_tag[:, 0, :, :]),
                 'Lund image (QCD) - tagged')

print('Saving figure lund_tagging.png')
plt.savefig('lund_tagging.png', bbox_inches='tight')