Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lofar/add pipeline plots #765

Closed
wants to merge 5 commits into from
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
397 changes: 397 additions & 0 deletions NuRadioReco/modules/LOFAR/pipelineVisualizer_LOFAR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,397 @@
"""
This module contains the pipelineVisualizer class for LOFAR.

.. moduleauthor:: Karen Terveer <[email protected]>
"""

import logging
import numpy as np
import matplotlib.pyplot as plt
import radiotools.helper as hp
import scipy
import radiotools
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize

from scipy.signal import resample

from NuRadioReco.utilities import units
from NuRadioReco.framework.parameters import stationParameters, channelParameters, showerParameters
from NuRadioReco.modules.base.module import register_run


class pipelineVisualizer:
"""
Creates debug plots from the LOFAR pipeline -
This is the pipelineVisualizerTM for LOFAR.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*glorious pipelineVisualizerTM :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you, Philipp


Any significant plots resulting from the pipeline
should be added here by creating a function for them,
and calling all functions sequentially.
"""

def __init__(self):

self.logger = logging.getLogger("NuRadioReco.pipelineVisualizer")


def begin(self, logger_level=logging.NOTSET):

self.__logger_level = logger_level
self.logger.setLevel(logger_level)


def plot_polarization(self, event, detector):

"""
Plot the polarization of the electric field.
This method calculates the stokes parameters of the pulse
using get_stokes from framework.electric_field, and
determines the polarization angle and degree, plotting
them as arrows in the vxB and vxvxB plane.
It estimates uncertainties by picking a pure noise value of
stokes parameters, propagating through the angle and degree
formulas and plotting them as arrows with reduced opacity.
Author: Karen Terveer

Parameters
----------
event : Event object
The event containing the stations and electric fields.
detector : Detector object
The detector object containing information about the detector.

Returns
-------
fig_pol : matplotlib Figure object
The generated figure object containing the polarization plot.
"""

from NuRadioReco.framework.electric_field import get_stokes

fig_pol, ax = plt.subplots(figsize=(8,7))

triggered_station_ids = [station.get_id() for station in event.get_stations() if station.get_parameter(stationParameters.triggered)]
num_stations = len(triggered_station_ids)
cmap = get_cmap('jet')
norm = Normalize(vmin=0, vmax=num_stations-1)

lora_core = event.get_hybrid_information().get_hybrid_shower("LORA").get_parameter(showerParameters.core)

try:
core = event.get_first_shower().get_parameter(showerParameters.core)

except:
self.logger.warning("No radio core found, using LORA core instead")
core = lora_core

for i, station in enumerate(event.get_stations()):
#for station in event.get_stations():

if station.get_parameter(stationParameters.triggered):

zenith = station.get_parameter(stationParameters.cr_zenith)
azimuth = station.get_parameter(stationParameters.cr_azimuth)
cs = radiotools.coordinatesystems.cstrafo(
zenith, azimuth, magnetic_field_vector=None, site="lofar")
efields = station.get_electric_fields()

station_pos = detector.get_absolute_position(station.get_id())
station_pos_vB = cs.transform_to_vxB_vxvxB(station_pos, core=core)[0]
station_pos_vvB = cs.transform_to_vxB_vxvxB(station_pos, core=core)[1]

ax.scatter(station_pos_vB, station_pos_vvB, color=cmap(norm(i)), s=20, label=f'Station CS{station.get_id():03d}')

for field in efields:

ids = field.get_channel_ids()
pos = station_pos + detector.get_relative_position(station.get_id(), ids[0])

# transform to vxB and vxvxB, assuming the LORA core reco.
# This is likely NOT the correct core position,
# it has to be determined from the radio data later

pos_vB = cs.transform_to_vxB_vxvxB(pos, core=core)[0]
pos_vvB = cs.transform_to_vxB_vxvxB(pos, core=core)[1]

pulse_window_start, pulse_window_end = station.get_channel(ids[0]).get_parameter(channelParameters.signal_regions)
pulse_window_len = pulse_window_end - pulse_window_start

trace = field.get_trace()[:,pulse_window_start:pulse_window_end]

efield_trace_vxB_vxvxB = cs.transform_to_vxB_vxvxB(
cs.transform_from_onsky_to_ground(trace)
)

#get stokes parameters
stokes = get_stokes(*efield_trace_vxB_vxvxB[:2], window_samples=64)

stokes_max = np.argmax(stokes[0])

I = stokes[0,stokes_max]
Q = stokes[1,stokes_max]
U = stokes[2,stokes_max]
V = stokes[3,stokes_max]

# get stokes uncertainties by picking a pure noise value
I_sigma = stokes[0, stokes_max-pulse_window_len//4]
Q_sigma = stokes[1, stokes_max-pulse_window_len//4]
U_sigma = stokes[2, stokes_max-pulse_window_len//4]
V_sigma = stokes[3, stokes_max-pulse_window_len//4]

pol_angle = 0.5 * np.arctan2(U,Q)
pol_angle_sigma= np.sqrt((U_sigma**2*(0.5*Q/(U**2+Q**2))**2 + Q_sigma**2*(0.5*U/(U**2+Q**2))**2))

# if the polarization deviates from the vxB direction by more than 80 degrees,
# this could indicate something wrong with the antenna. Show a warning including
# the channel ids
if np.abs(0.5 * np.arctan2(U,Q)) > 80*np.pi/180:
self.logger.warning("strange polarization direction in channel group %s" % ids)

pol_degree= np.sqrt(U**2 + Q**2 + V**2) / I
pol_degree *= 7 # scale for better visibility

dx = pol_degree * np.cos(pol_angle)
dy = pol_degree * np.sin(pol_angle)

dx_sigma_plus = pol_degree * np.cos(pol_angle + pol_angle_sigma)
dy_sigma_plus = pol_degree * np.sin(pol_angle + pol_angle_sigma)

dx_sigma_minus = pol_degree * np.cos(pol_angle - pol_angle_sigma)
dy_sigma_minus = pol_degree * np.sin(pol_angle - pol_angle_sigma)

ax.arrow(pos_vB, pos_vvB, dx_sigma_plus, dy_sigma_plus, head_width=2, head_length=5, fc=cmap(norm(i)), ec = cmap(norm(i)), alpha=0.5)
ax.arrow(pos_vB, pos_vvB, dx_sigma_minus, dy_sigma_minus, head_width=2, head_length=5, ec = cmap(norm(i)), fc=cmap(norm(i)), alpha=0.5)
ax.arrow(pos_vB, pos_vvB, dx, dy, head_width=2, head_length=6, fc=cmap(norm(i)), ec = cmap(norm(i)))

if (core != lora_core).all():
lora_vB = cs.transform_to_vxB_vxvxB(lora_core, core=core)[0]
lora_vvB = cs.transform_to_vxB_vxvxB(lora_core, core=core)[1]
ax.scatter(lora_vB, lora_vvB, color='tab:red', s=50, label='LORA core', marker = 'x')
label = 'radio core'
else:
label = 'LORA core'

ax.scatter([0], [0], color='black', s=50, label=label, marker = 'x')
ax.legend()
ax.set_xlabel('Direction along $v \\times B$ [m]')
ax.set_ylabel('Direction along $v \\times (v \\times B)$ [m]')

return fig_pol

def show_direction_plot(self, event):

"""
Create the final plot for the plane wave fit direction
reconstruction. Author: Philipp Laub

Parameters
----------
event : Event object
The event for which to show the final plots.

Returns
-------
fig_dir : matplotlib Figure object
The generated figure object containing the direction plot.
"""

# plot reconstructed directions of all stations and compare to LORA in polar plot:
fig_dir, ax = plt.subplots(subplot_kw={'projection': 'polar'})
ax.set_theta_zero_location('E')
ax.set_theta_direction(1)

triggered_station_ids = [station.get_id() for station in event.get_stations() if station.get_parameter(stationParameters.triggered)]
num_stations = len(triggered_station_ids)
cmap = get_cmap('jet')
norm = Normalize(vmin=0, vmax=num_stations-1)

for i, station in enumerate(event.get_stations()):
if station.get_parameter(stationParameters.triggered):
zenith = station.get_parameter(stationParameters.cr_zenith)
azimuth = station.get_parameter(stationParameters.cr_azimuth)
ax.plot(azimuth,
zenith,
label=f'Station CS{station.get_id():03d}',
marker='P',
markersize=7,
linestyle='',
color=cmap(norm(i))
)

ax.plot(event.get_hybrid_information().get_hybrid_shower("LORA").get_parameter(showerParameters.azimuth),
event.get_hybrid_information().get_hybrid_shower("LORA").get_parameter(showerParameters.zenith),
label='LORA',
marker="X",
markersize=7,
linestyle='',
color='black')
ax.legend()
plt.title("Reconstructed arrival directions")

return fig_dir

def show_time_fluence_plot(self, event, detector, min_number_good_antennas=4):

"""
Create the final plot for the plane wave fit, including
timing and pseudofluence. Author: Philipp Laub

Parameters
----------
event : Event object
The event for which to show the final plots.
detector : Detector object
The detector for which to show the final plots.
min_number_good_antennas : int, default=4
The minimum number of good antennas that should be
present in a station to consider it for the fit.

Returns
-------
fig_pol : matplotlib Figure object
The generated figure object containing the polarization plot.
"""

# plot the antenna positions and mark arrival time by color and "fluence" by markersize.
# Also indicate the reconstructed arrival direction per station via an arrow.

from astropy.time import Time

time = detector.get_detector_time().utc

if time.mjd < 56266:
self.logger.warning("Event was before Dec 1, 2012. The non-core station clocks might be off.")

good_antennas_dict = {}

for station in event.get_stations():
if station.get_parameter(stationParameters.triggered):
good_antennas_dict[station.get_id()] = []
flagged_channels = station.get_parameter(stationParameters.flagged_channels)
# Get all group IDs which are still present in the station
station_channel_group_ids = set([channel.get_group_id() for channel in station.iter_channels()])

# Get the dominant polarisation orientation as calculated by stationPulseFinder
dominant_orientation = station.get_parameter(stationParameters.cr_dominant_polarisation)

good_channel_pair_ids = np.zeros((len(station_channel_group_ids), 2), dtype=int)
for ind, channel_group_id in enumerate(station_channel_group_ids):
for channel in station.iter_channel_group(channel_group_id):
if np.all(detector.get_antenna_orientation(station.get_id(), channel.get_id()) == dominant_orientation):
good_channel_pair_ids[ind, 0] = channel.get_id()
else:
good_channel_pair_ids[ind, 1] = channel.get_id()

# Check if dominant channel has been flagged
channel = station.get_channel(good_channel_pair_ids[ind, 0])
if channel.get_id() not in flagged_channels:
good_antennas_dict[station.get_id()].append(channel.get_id())

fig_time, ax = plt.subplots(dpi=150, figsize=(8, 5))

triggered_station_ids = [station.get_id() for station in event.get_stations() if station.get_parameter(stationParameters.triggered)]
num_stations = len(triggered_station_ids)
cmap = get_cmap('jet')
norm = Normalize(vmin=0, vmax=num_stations-1)

fluences = []
positions = []
SNRs = []
for station in event.get_stations():
if station.get_parameter(stationParameters.triggered):
zenith = station.get_parameter(stationParameters.cr_zenith)
azimuth = station.get_parameter(stationParameters.cr_azimuth)
good_antennas = good_antennas_dict[station.get_id()]
if len(good_antennas) >= min_number_good_antennas:
for antenna in good_antennas:
positions.append(detector.get_relative_position(station.get_id(), antenna) + detector.get_absolute_position(station.get_id()))
channel = station.get_channel(antenna)
SNRs.append(channel.get_parameter(channelParameters.SNR))
fluences.append(np.sum(np.square(channel.get_trace())))
station_pos = detector.get_absolute_position(station.get_id())
ax.quiver(station_pos[0], station_pos[1],
np.cos(azimuth), np.sin(azimuth),
color='black',
scale=0.02,
scale_units='xy',
angles='uv',
width=0.005)

timelags = []
for station in event.get_stations():
if station.get_parameter(stationParameters.triggered):
good_antennas = good_antennas_dict[station.get_id()]
if len(good_antennas) >= min_number_good_antennas:
for channel_id in good_antennas:
timelags.append(station.get_channel(channel_id).get_parameter(channelParameters.signal_time))

for i, station in enumerate(event.get_stations()):
# plot absolute station positions
if station.get_parameter(stationParameters.triggered):
station_pos = detector.get_absolute_position(station.get_id())
ax.scatter(station_pos[0], station_pos[1], color=cmap(norm(i)), s=20, label=f'Station CS{station.get_id():03d}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a different marker would be good to differentiate between station positions and antenna positions later (square or star?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! (pun intended)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use points!


timelags = np.array(timelags)
timelags -= timelags[0] # get timelags wrt 1st antenna
# plot all locations and use arrival time for color and fluence for marker size and add a colorbar
positions = np.array(positions)
fluences = np.array(fluences)
SNRs = np.array(SNRs)
fluence_norm = Normalize(vmin=np.min(fluences), vmax=np.max(fluences))
sc = ax.scatter(
positions[:,0],
positions[:,1],
c=timelags,
s=15 * fluence_norm(fluences),
cmap='viridis',
zorder=-1)

ax.set_aspect('equal')
plt.colorbar(sc, label='Relative arrival time [ns]', shrink=0.7)
ax.set_xlabel('Meters east [m]')
ax.set_ylabel('Meters north [m]')
plt.legend()
plt.title("Antenna positions and arrival time")

return fig_time


@register_run()
def run(self, event, detector, save_dir='.', polarization=False, direction=False):
"""
Produce pipeline plots for the given event.

Parameters
----------
event : Event object
The event for which to visualize the pipeline.
detector : Detector object
The detector for which to visualize the pipeline.
save_dir : str, optional
The directory to save the plots to. Default is the
current directory.
"""

plots = []
if polarization:
pol_plot = self.plot_polarization(event, detector)
plots.append(pol_plot)
pol_plot.savefig(f'{save_dir}/polarization_plot_{event.get_id()}.png')

if direction:
dir_plot = self.show_direction_plot(event)
plots.append(dir_plot)
dir_plot.savefig(f'{save_dir}/direction_plot_{event.get_id()}.png')

time_fluence_plot = self.show_time_fluence_plot(event, detector)
plots.append(time_fluence_plot)
time_fluence_plot.savefig(f'{save_dir}/time_fluence_plot_{event.get_id()}.png')

self.plots = [plot for plot in plots]



def end(self):

pass