Skip to content

Commit

Permalink
small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
karenterveer authored and MijnheerD committed Nov 21, 2024
1 parent cccd171 commit 6ac4656
Showing 1 changed file with 54 additions and 19 deletions.
73 changes: 54 additions & 19 deletions NuRadioReco/modules/LOFAR/pipelineVisualizer_LOFAR.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import radiotools.helper as hp
import scipy
import radiotools
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize

from scipy.signal import resample

Expand Down Expand Up @@ -67,10 +69,12 @@ def plot_polarization(self, event, detector):

from NuRadioReco.framework.electric_field import get_stokes



fig_pol, ax = plt.subplots(figsize=(8,7))
fcs = ['black', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'grey', 'cyan', 'magenta', 'yellow', 'red']

triggered_station_ids = [station.get_id() for station in event.get_stations() if station.get_parameter(stationParameters.triggered)]
num_stations = len(triggered_station_ids)
cmap = get_cmap('jet')
norm = Normalize(vmin=0, vmax=num_stations-1)

lora_core = event.get_hybrid_information().get_hybrid_shower("LORA").get_parameter(showerParameters.core)

Expand All @@ -96,7 +100,7 @@ def plot_polarization(self, event, detector):
station_pos_vB = cs.transform_to_vxB_vxvxB(station_pos, core=core)[0]
station_pos_vvB = cs.transform_to_vxB_vxvxB(station_pos, core=core)[1]

ax.scatter(station_pos_vB, station_pos_vvB, color=fcs[i], s=20, label=f'Station CS{station.get_id():03d}')
ax.scatter(station_pos_vB, station_pos_vvB, color=cmap(norm(i)), s=20, label=f'Station CS{station.get_id():03d}')

for field in efields:

Expand Down Expand Up @@ -156,19 +160,19 @@ def plot_polarization(self, event, detector):
dx_sigma_minus = pol_degree * np.cos(pol_angle - pol_angle_sigma)
dy_sigma_minus = pol_degree * np.sin(pol_angle - pol_angle_sigma)

ax.arrow(pos_vB, pos_vvB, dx_sigma_plus, dy_sigma_plus, head_width=2, head_length=5, fc=fcs[i], ec = fcs[i], alpha=0.5)
ax.arrow(pos_vB, pos_vvB, dx_sigma_minus, dy_sigma_minus, head_width=2, head_length=5, ec = fcs[i], fc=fcs[i], alpha=0.5)
ax.arrow(pos_vB, pos_vvB, dx, dy, head_width=2, head_length=6, fc=fcs[i], ec = fcs[i])
ax.arrow(pos_vB, pos_vvB, dx_sigma_plus, dy_sigma_plus, head_width=2, head_length=5, fc=cmap(norm(i)), ec = cmap(norm(i)), alpha=0.5)
ax.arrow(pos_vB, pos_vvB, dx_sigma_minus, dy_sigma_minus, head_width=2, head_length=5, ec = cmap(norm(i)), fc=cmap(norm(i)), alpha=0.5)
ax.arrow(pos_vB, pos_vvB, dx, dy, head_width=2, head_length=6, fc=cmap(norm(i)), ec = cmap(norm(i)))

if (core != lora_core).all():
lora_vB = cs.transform_to_vxB_vxvxB(lora_core, core=core)[0]
lora_vvB = cs.transform_to_vxB_vxvxB(lora_core, core=core)[1]
ax.scatter(lora_vB, lora_vvB, color='tab:olive', s=50, label='LORA core', marker = 'x')
ax.scatter(lora_vB, lora_vvB, color='tab:red', s=50, label='LORA core', marker = 'x')
label = 'radio core'
else:
label = 'LORA core'

ax.scatter([0], [0], color='red', s=50, label=label, marker = 'x')
ax.scatter([0], [0], color='black', s=50, label=label, marker = 'x')
ax.legend()
ax.set_xlabel('Direction along $v \\times B$ [m]')
ax.set_ylabel('Direction along $v \\times (v \\times B)$ [m]')
Expand Down Expand Up @@ -196,7 +200,13 @@ def show_direction_plot(self, event):
fig_dir, ax = plt.subplots(subplot_kw={'projection': 'polar'})
ax.set_theta_zero_location('E')
ax.set_theta_direction(1)
for station in event.get_stations():

triggered_station_ids = [station.get_id() for station in event.get_stations() if station.get_parameter(stationParameters.triggered)]
num_stations = len(triggered_station_ids)
cmap = get_cmap('jet')
norm = Normalize(vmin=0, vmax=num_stations-1)

for i, station in enumerate(event.get_stations()):
if station.get_parameter(stationParameters.triggered):
zenith = station.get_parameter(stationParameters.cr_zenith)
azimuth = station.get_parameter(stationParameters.cr_azimuth)
Expand All @@ -205,7 +215,9 @@ def show_direction_plot(self, event):
label=f'Station CS{station.get_id():03d}',
marker='P',
markersize=7,
linestyle='')
linestyle='',
color=cmap(norm(i))
)

ax.plot(event.get_hybrid_information().get_hybrid_shower("LORA").get_parameter(showerParameters.azimuth),
event.get_hybrid_information().get_hybrid_shower("LORA").get_parameter(showerParameters.zenith),
Expand Down Expand Up @@ -244,7 +256,6 @@ def show_time_fluence_plot(self, event, detector, min_number_good_antennas=4):
# plot the antenna positions and mark arrival time by color and "fluence" by markersize.
# Also indicate the reconstructed arrival direction per station via an arrow.

from matplotlib.colors import Normalize
from astropy.time import Time

time = detector.get_detector_time().utc
Expand All @@ -253,6 +264,7 @@ def show_time_fluence_plot(self, event, detector, min_number_good_antennas=4):
self.logger.warning("Event was before Dec 1, 2012. The non-core station clocks might be off.")

good_antennas_dict = {}

for station in event.get_stations():
if station.get_parameter(stationParameters.triggered):
good_antennas_dict[station.get_id()] = []
Expand All @@ -277,6 +289,12 @@ def show_time_fluence_plot(self, event, detector, min_number_good_antennas=4):
good_antennas_dict[station.get_id()].append(channel.get_id())

fig_time, ax = plt.subplots(dpi=150, figsize=(8, 5))

triggered_station_ids = [station.get_id() for station in event.get_stations() if station.get_parameter(stationParameters.triggered)]
num_stations = len(triggered_station_ids)
cmap = get_cmap('jet')
norm = Normalize(vmin=0, vmax=num_stations-1)

fluences = []
positions = []
SNRs = []
Expand Down Expand Up @@ -307,7 +325,13 @@ def show_time_fluence_plot(self, event, detector, min_number_good_antennas=4):
if len(good_antennas) >= min_number_good_antennas:
for channel_id in good_antennas:
timelags.append(station.get_channel(channel_id).get_parameter(channelParameters.signal_time))


for i, station in enumerate(event.get_stations()):
# plot absolute station positions
if station.get_parameter(stationParameters.triggered):
station_pos = detector.get_absolute_position(station.get_id())
ax.scatter(station_pos[0], station_pos[1], color=cmap(norm(i)), s=20, label=f'Station CS{station.get_id():03d}')

timelags = np.array(timelags)
timelags -= timelags[0] # get timelags wrt 1st antenna
# plot all locations and use arrival time for color and fluence for marker size and add a colorbar
Expand All @@ -322,17 +346,19 @@ def show_time_fluence_plot(self, event, detector, min_number_good_antennas=4):
s=15 * fluence_norm(fluences),
cmap='viridis',
zorder=-1)

ax.set_aspect('equal')
plt.colorbar(sc, label='Relative arrival time [ns]', shrink=0.7)
ax.set_xlabel('Meters east [m]')
ax.set_ylabel('Meters north [m]')
plt.legend()
plt.title("Antenna positions and arrival time")

return fig_time


@register_run()
def run(self, event, detector, polarization=False, direction=False):
def run(self, event, detector, save_dir='.', polarization=False, direction=False):
"""
Produce pipeline plots for the given event.
Expand All @@ -342,19 +368,28 @@ def run(self, event, detector, polarization=False, direction=False):
The event for which to visualize the pipeline.
detector : Detector object
The detector for which to visualize the pipeline.
save_dir : str, optional
The directory to save the plots to. Default is the
current directory.
"""

plots = []
if polarization:
plots.append(self.plot_polarization(event, detector))
pol_plot = self.plot_polarization(event, detector)
plots.append(pol_plot)
pol_plot.savefig(f'{save_dir}/polarization_plot_{event.get_id()}.png')

if direction:
plots.append(self.show_direction_plot(event))
plots.append(self.show_time_fluence_plot(event, detector))
dir_plot = self.show_direction_plot(event)
plots.append(dir_plot)
dir_plot.savefig(f'{save_dir}/direction_plot_{event.get_id()}.png')

time_fluence_plot = self.show_time_fluence_plot(event, detector)
plots.append(time_fluence_plot)
time_fluence_plot.savefig(f'{save_dir}/time_fluence_plot_{event.get_id()}.png')

self.plots = [plot for plot in plots]

plt.savefig('/vol/astro7/lofar/kterveer/projects/pipeline/scripts/plots.png')
plt.show()


def end(self):
Expand Down

0 comments on commit 6ac4656

Please sign in to comment.