Skip to content

Commit

Permalink
Merge pull request #755 from nu-radio/refactor-fs2
Browse files Browse the repository at this point in the history
Refactor output_writer_hdf5.py. Mostly WS
  • Loading branch information
fschlueter authored Nov 22, 2024
2 parents 6571b68 + 9fab51c commit e65ceeb
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions NuRadioMC/simulation/output_writer_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
# self._n_showers = len(self._fin['shower_ids'])

# self._hdf5_keys = ['event_group_ids', 'xx', 'yy', 'zz', 'vertex_times',
# 'azimuths', 'zemiths', 'energies',
# 'azimuths', 'zemiths', 'energies',
# 'shower_energies', ''n_interaction',
# 'shower_ids', 'event_ids', 'triggered', 'n_samples',
# 'dt', 'Tnoise', 'Vrms', 'bandwidth', 'trigger_names']
Expand Down Expand Up @@ -98,37 +98,40 @@ def add_event_group(self, event_buffer):
"""
logger.debug("adding event group to output file")



# first add attributes to the output file. Attributes of all events should be the same,
# raising an error if not.
# first add attributes to the output file. Attributes of all events should be the same,
# raising an error if not.
trigger_names = []
extent_array_by = 0
for sid in event_buffer:
for eid in event_buffer[sid]:
evt = event_buffer[sid][eid]

# save event attributes
for enum_entry in genattrs:
if(evt.has_generator_info(enum_entry)):
if evt.has_generator_info(enum_entry):
if enum_entry.name not in self._mout_attributes:
self._mout_attributes[enum_entry.name] = evt.get_generator_info(enum_entry)
else: # if the attribute is already present, we need to check if it is the same for all events
assert all(np.atleast_1d(self._mout_attributes[enum_entry.name] == evt.get_generator_info(enum_entry)))

# save station attributes
for stn in evt.get_stations():
# save station attributes
tmp_keys = [[chp.Vrms_NuRadioMC_simulation, "Vrms"], [chp.bandwidth_NuRadioMC_simulation, "bandwidth"]]
for (key_cp, key_hdf5) in tmp_keys:
channel_values = []
for channel in stn.iter_channels(sorted=True):
channel_values.append(channel[key_cp])
station_key_pairs = [[chp.Vrms_NuRadioMC_simulation, "Vrms"], [chp.bandwidth_NuRadioMC_simulation, "bandwidth"]]
for (key_cp, key_hdf5) in station_key_pairs:
channel_values = [channel[key_cp] for channel in stn.iter_channels(sorted=True)]

if key_hdf5 not in self._mout_groups_attributes[sid]:
self._mout_groups_attributes[sid][key_hdf5] = np.array(channel_values)
else:
assert all(np.atleast_1d(self._mout_groups_attributes[sid][key_hdf5] == np.array(channel_values))), f"station {sid} key {key_hdf5} is {self._mout_groups_attributes[sid][key_hdf5]}, but current channel is {np.array(channel_values)}"
assert all(np.atleast_1d(self._mout_groups_attributes[sid][key_hdf5] == np.array(channel_values))), \
f"station {sid} key {key_hdf5} is {self._mout_groups_attributes[sid][key_hdf5]}, but current channel is {np.array(channel_values)}"

for trigger in stn.get_triggers().values():
if trigger.get_name() not in trigger_names:
trigger_names.append(trigger.get_name())
logger.debug(f"extending data structure by trigger {trigger.get_name()} to output file")
extent_array_by += 1

# the available triggers are not available from the start because a certain trigger
# might only trigger a later event. Therefore we need to extend the array
# if we find a new trigger
Expand All @@ -140,6 +143,7 @@ def add_event_group(self, event_buffer):
for i in range(len(self._mout[key])):
logger.debug(f"extending data structure by {extent_array_by} to output file for key {key}")
self._mout[key][i] = self._mout[key][i] + [False] * extent_array_by

for station_id in self._station_ids:
sg = self._mout_groups[station_id]
if keys[0] in sg:
Expand All @@ -156,7 +160,7 @@ def add_event_group(self, event_buffer):
evt = event_buffer[sid][eid]
if self._particle_mode:
for shower in evt.get_sim_showers():
if not shower.get_id() in shower_ids:
if shower.get_id() not in shower_ids:
logger.debug(f"adding shower {shower.get_id()} to output file")
# shower ids might not be in increasing order. We need to sort the hdf5 output later
shower_ids.append(shower.get_id())
Expand Down Expand Up @@ -185,7 +189,9 @@ def add_event_group(self, event_buffer):
self.__first_event = False
else: # emitters have different properties, so we need to treat them differently than showers
for emitter in evt.get_sim_emitters():
if not emitter.get_id() in shower_ids: # the key "shower_ids" is also used for emitters and identifies the emitter id. This is done because it is the only way to have the same input files for both shower/particle and emitter simulations.
# the key "shower_ids" is also used for emitters and identifies the emitter id. This is done because it is
# the only way to have the same input files for both shower/particle and emitter simulations.
if emitter.get_id() not in shower_ids:
logger.debug(f"adding shower {emitter.get_id()} to output file")
# shower ids might not be in increasing order. We need to sort the hdf5 output later
shower_ids.append(emitter.get_id())
Expand All @@ -204,8 +210,9 @@ def add_event_group(self, event_buffer):
self.__add_parameter(self._mout, keyname, emitter[key], self.__first_event)

self.__first_event = False

# now save station data
stn = evt.get_station() # there can only ever be one station per event! If there are more than one station, this line will crash.
stn = evt.get_station() # there can only ever be one station per event! If there are more than one station, this line will crash.
sg = self._mout_groups[sid]
self.__add_parameter(sg, 'event_group_ids', evt.get_run_number())
self.__add_parameter(sg, 'event_ids', evt.get_id())
Expand All @@ -232,8 +239,8 @@ def add_event_group(self, event_buffer):

self.__add_parameter(sg, 'triggered', stn.has_triggered())

# depending on the simulation mode we have either showers or emitters but we can
# treat them the same way as long as we only call common member functions such as
# depending on the simulation mode we have either showers or emitters but we can
# treat them the same way as long as we only call common member functions such as
# `get_id()`
iterable = None
if self._particle_mode:
Expand All @@ -250,9 +257,9 @@ def add_event_group(self, event_buffer):
# we need to save data per shower, channel and ray tracing solution. Due to the simple table structure
# of the hdf5 files we need to preserve the ordering of the showers and channels. As the order in the
# NuRadio data structure is different, we need to go through some effort to get the right order.
# The shower ids will be sorted at the very end.
# The channel ids already have the correct ordering.
# The ray tracing solutions are also ordered, because the efield object contains the correct ray tracing solution id.
# The shower ids will be sorted at the very end.
# The channel ids already have the correct ordering.
# The ray tracing solutions are also ordered, because the efield object contains the correct ray tracing solution id.
channel_rt_data = {}
keys_channel_rt_data = ['travel_times', 'travel_distances']
if self._mout_attributes['config']['speedup']['amp_per_ray_solution']:
Expand All @@ -266,7 +273,7 @@ def add_event_group(self, event_buffer):
# important: we need to loop over the channels of the station object, not
# the channels present in the sim_station object. This is because the sim
# channel object only contains the channels that have a signal, i.e., a ray
# tracing solution and a strong enough Askaryan signal. But we want to loop over all
# tracing solution and a strong enough Askaryan signal. But we want to loop over all
# channels of the station, because we want to save the data for all channels, not only
# the ones that have a signal. This also preserves the order of the channels.
for iCh, channel in enumerate(stn.iter_channels(sorted=True)):
Expand Down Expand Up @@ -310,9 +317,9 @@ def add_event_group(self, event_buffer):
trigger_times = np.ones((len(shower_ids_stn), len(self._mout_attributes['trigger_names'])), dtype=float) * np.nan
for eid in event_buffer[sid]:
evt = event_buffer[sid][eid]
stn = evt.get_station() # there can only ever be one station per event! If there are more than one station, this line will crash.
# depending on the simulation mode we have either showers or emitters but we can
# treat them the same way as long as we only call common member functions such as
stn = evt.get_station() # there can only ever be one station per event! If there are more than one station, this line will crash.
# depending on the simulation mode we have either showers or emitters but we can
# treat them the same way as long as we only call common member functions such as
# `get_id()`
iterable = None
if self._particle_mode:
Expand Down Expand Up @@ -372,7 +379,7 @@ def add_event_group(self, event_buffer):

if self._particle_mode:
# we also want to save the first interaction even if it didn't contribute to any trigger
# this is important to know the initial neutrino properties (only relevant for the simulation of
# this is important to know the initial neutrino properties (only relevant for the simulation of
# secondary interactions)
stn_buffer = event_buffer[self._station_ids[0]]
evt = stn_buffer[list(stn_buffer.keys())[0]]
Expand Down Expand Up @@ -522,7 +529,7 @@ def calculate_Veff(self):
float: The calculated effective volume (Veff)
"""
# calculate effective
try: # sometimes not all relevant attributes are set, e.g. for emitter simulations.
try: # sometimes not all relevant attributes are set, e.g. for emitter simulations.
triggered = remove_duplicate_triggers(self._mout['triggered'], self._mout['event_group_ids'])
n_triggered = np.sum(triggered)
n_triggered_weighted = np.sum(np.array(self._mout['weights'])[triggered])
Expand Down

0 comments on commit e65ceeb

Please sign in to comment.