Skip to content

Commit

Permalink
Add some more functions, fixed the import/export functionality. Added…
Browse files Browse the repository at this point in the history
… some safeguards. Added some docstrings
  • Loading branch information
fschlueter committed Aug 30, 2023
1 parent 0d0935f commit 87eae32
Showing 1 changed file with 114 additions and 103 deletions.
217 changes: 114 additions & 103 deletions NuRadioReco/detector/RNO_G/rnog_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,16 @@ def __init__(self, database_connection='RNOG_test_public', log_level=logging.DEB
self.logger = logging.getLogger("rno-g-detector")
self.logger.setLevel(log_level)

# Define default values for parameter not (yet) implemented in DB. Those values are taken for all channels.
self.__default_values = {
"noise_temperature": 300 * units.kelvin,
"sampling_frequency": 3.2 * units.GHz,
"number_of_samples": 2048,
"is_noiseless": False
}

if pickle_file is None:
self._det_imported_from_file = False

self.__db = Database(database_connection=database_connection)
if database_time is not None:
Expand All @@ -93,7 +102,8 @@ def __init__(self, database_connection='RNOG_test_public', log_level=logging.DEB

self._query_all = always_query_entire_description
else:
self._query_all = False
self._query_all = None # specific case for file imported detector descriptions
self._det_imported_from_file = True

import pickle

Expand All @@ -103,18 +113,12 @@ def __init__(self, database_connection='RNOG_test_public', log_level=logging.DEB
self.__buffered_stations = import_dir["data"]
self._time_periods_per_station = import_dir["periods"]
self._time_period_index_per_station = {st_id: 1 for st_id in self.__buffered_stations}
self.__default_values = import_dir["default_values"]
else:
self.logger.error(f"{pickle_file} with unknown version.")
raise ReferenceError(f"{pickle_file} with unknown version.")

# Define default values for parameter not (yet) implemented in DB. Those values are taken for all channels.
self.__default_values = {
"noise_temperature": 300 * units.kelvin,
"sampling_frequency": 3.2 * units.GHz,
"number_of_samples": 2048,
"is_noiseless": False
}

# Allow overwriting the hard-coded values
self.__default_values.update(over_write_handset_values)

info = f"Query entire detector description at once: {self._query_all}"
Expand All @@ -139,19 +143,29 @@ def export(self, filename):

periods = {}
for station_id in self.__buffered_stations:

# Remove decommissioned stations from the buffer completely
if self.__buffered_stations[station_id] == {}:
self.__buffered_stations.pop(station_id)
continue

idx = self._time_period_index_per_station[station_id]
if idx == 0 or idx == len(self._time_periods_per_station[station_id]["modification_timestamps"]):
self.logger.error("You try to export a decomissioned station")
periods[station_id] = [self._time_periods_per_station[station_id]["modification_timestamps"][idx],
self._time_periods_per_station[station_id]["modification_timestamps"][idx+1]]

periods[station_id] = {"modification_timestamps":
[self._time_periods_per_station[station_id]["modification_timestamps"][idx - 1],
self._time_periods_per_station[station_id]["modification_timestamps"][idx]]
}

export_dir = {
export_dict = {
"version": 1,
"data": self.__buffered_stations,
"periods": periods
"periods": periods,
"default_values": self.__default_values
}

pickle.dump(export_dir, open(filename, "wb"))
pickle.dump(export_dict, open(filename, "wb"))


def _check_update_buffer(self):
Expand Down Expand Up @@ -231,11 +245,16 @@ def update(self, time):
self.logger.debug(f"Update detector to {time}")

self.__set_detector_time(time)
self.__db.set_detector_time(time)
if not self._det_imported_from_file:
self.__db.set_detector_time(time)

update_buffer_for_station = self._check_update_buffer()
any_update = np.any([v for v in update_buffer_for_station.values()])

if self._det_imported_from_file and any_update:
self.logger.error("You have imported the detector description from a pickle file but it is not valid anymore. Full stop!")
raise ValueError("You have imported the detector description from a pickle file but it is not valid anymore. Full stop!")

if any_update:
for key in self.__buffered_stations:
if update_buffer_for_station[station_id]:
Expand Down Expand Up @@ -305,10 +324,24 @@ def has_station(self, station_id):

def _query_station_information(self, station_id, all=True):
"""
Query information about a specific station from the database via the db_mongo_read interface.
You can query only information from the station_list collection (all=False) or the complete
information of the station (all=True).
Parameters
----------
station_id: int
Station id
all: bool
If true, query all relevant information form a station including its channel and devices (position, signal chain, ...).
If false, query only the information from the station list collection (describes a station with all channels and devices
with their (de)commissioning timestamps but not data like position, signal chain, ...)
Returns
-------
None
"""

if station_id in self.__buffered_stations and self.__buffered_stations[station_id] != {}:
Expand Down Expand Up @@ -404,10 +437,6 @@ def __get_channel(self, station_id, channel_id, with_position=False, with_signal
"""
if keys_not_in_dict(self.__buffered_stations, [station_id, "channels", channel_id]):
raise KeyError(f"Could not find channel {channel_id} in detector description for station {station_id}.")

# print(self.__buffered_stations[station_id]["channels"][channel_id].keys())
# print(self.__buffered_stations[station_id]["channels"][channel_id]["signal_ch"].keys())
# print(self.__buffered_stations[station_id]["channels"][channel_id]["id_signal"])

if with_position and keys_not_in_dict(self.__buffered_stations, [station_id, "channels", channel_id, "channel_position"]):

Expand All @@ -427,7 +456,6 @@ def __get_channel(self, station_id, channel_id, with_position=False, with_signal

signal_id = self.__buffered_stations[station_id]["channels"][channel_id]['id_signal']
channel_sig_info = self.__db.get_channel_signal_chain(signal_id)
# remove 'id_measurement' and 'channel_id' object
channel_sig_info.pop('channel_id', None)

self.__buffered_stations[station_id]["channels"][channel_id]['signal_chain'] = channel_sig_info
Expand Down Expand Up @@ -471,7 +499,7 @@ def get_absolute_position(self, station_id):

def get_relative_position(self, station_id, channel_id):
"""
Get the relative position of a specific channel/antenna or device with respect to the station center
Get the relative position of a specific channel/antenna with respect to the station center
Parameters
----------
Expand All @@ -490,44 +518,7 @@ def get_relative_position(self, station_id, channel_id):
"""
channel_info = self.__get_channel(station_id, channel_id, with_position=True)
return channel_info["channel_position"]['position']


def get_relative_position_device(self, station_id, device_id):
"""
Get the relative position of a specific channel/antenna or device with respect to the station center
Parameters
----------
station_id: int
The station id
device_id: str
Device name
Returns
-------
pos: np.array(3,)
3-dim array of relative station position
"""
if keys_not_in_dict(self.__buffered_stations, [station_id, "devices", device_id]):

if "devices" not in self.__buffered_stations[station_id]:
self.__buffered_stations[station_id]["devices"] = collections.OrderedDict()

device_position_information = self.__db.get_collection_information('device_position', station_id)

if device_position_information is None:
return None
else:
#TODO: Check if the code actually works
for ele in device_position_information:
info = ele["measurements"]
self.__buffered_stations[station_id]["devices"][info["device_id"]] = info

return self.__buffered_stations[station_id]["devices"][device_id]["position"]



def get_channel_orientation(self, station_id, channel_id):
"""
Expand Down Expand Up @@ -561,7 +552,7 @@ def get_channel_orientation(self, station_id, channel_id):
rotation = channel_info['channel_position']["rotation"]

return orientation["theta"], orientation["phi"], rotation["theta"], rotation["phi"]


def get_channel_signal_chain(self, station_id, channel_id):
"""
Expand All @@ -583,13 +574,6 @@ def get_channel_signal_chain(self, station_id, channel_id):
describe the signal chain of the channel
"""
channel_info = self.__get_channel(station_id, channel_id, with_signal_chain=True)

print(channel_info.keys())
for key in channel_info:
if isinstance(channel_info[key], dict):
print("\t", key, channel_info[key].keys())
# print(channel_info[key].values())

return self.__buffered_stations[station_id]["channels"][channel_id]["signal_chain"]


Expand Down Expand Up @@ -622,6 +606,67 @@ def get_signal_chain_response(self, station_id, channel_id):
responses.append(Response(value["frequencies"], ydata, value["y-axis_units"], name=key))

return np.prod(responses)


def get_devices(self, station_id):
"""
Get all devices for a particular station.
Parameters
----------
station_id: int
Station id
Returns
-------
devices: dict(str)
Dictonary of all devices with {id: name}.
"""

if not self.has_station(station_id):
self.logger.error(f"Station {station_id} not commissioned at {self.get_detector_time()}. Return empty device list")
return []

return {device["id"]: device["device_name"] for device in self.__buffered_stations[station_id]["devices"].values()}


def get_relative_position_device(self, station_id, device_id):
"""
Get the relative position of a specific device with respect to the station center
Parameters
----------
station_id: int
The station id
device_id: int
Device identifier
Returns
-------
pos: np.array(3,)
3-dim array of relative station position
"""
if keys_not_in_dict(self.__buffered_stations, [station_id, "devices", device_id]):
# All devices should have been queried with _query_station_information
raise KeyError(f"Device {device_id} not in detector description.")

if keys_not_in_dict(self.__buffered_stations, [station_id, "devices", device_id, "device_position"]):

if keys_not_in_dict(self.__buffered_stations, [station_id, "devices", device_id, "id_position"]):
raise KeyError(f"\"id_position\" not in buffer for device {device_id}.")

position_id = self.__buffered_stations[station_id]["devices"][device_id]["id_position"]

device_pos_info = self.__db.get_device_position(device_position_id=position_id)
self.__buffered_stations[station_id]["devices"][device_id]['device_position'] = device_pos_info

return self.__buffered_stations[station_id]["devices"][device_id]["device_position"]["position"]


# def _has_valid_parameter_in_buffer(self, key_list):
# """
Expand Down Expand Up @@ -816,7 +861,6 @@ def __init__(self, frequency, y, y_unit, name="default"):

if y_unit[0] == "dB":
gain = helper.dB_to_linear(y_ampl)
print(gain)
elif y_unit[0] == "mag" or y_unit[0] == "MAG":
gain = y_ampl
else:
Expand All @@ -831,7 +875,6 @@ def __init__(self, frequency, y, y_unit, name="default"):

self.__gains = [interpolate.interp1d(
self.__frequency, gain, kind="linear", bounds_error=False, fill_value=0)]
print(self.__gains[0](0.5))

self.__phases = [interpolate.interp1d(
self.__frequency, y_phase, kind="linear", bounds_error=False, fill_value=0)]
Expand Down Expand Up @@ -903,38 +946,6 @@ def __str__(self):


if __name__ == "__main__":
det = Detector(log_level=logging.DEBUG, over_write_handset_values={"sampling_frequency": 2.4 * units.GHz}, always_query_entire_description=False)
det = Detector(log_level=logging.DEBUG, over_write_handset_values={"sampling_frequency": 2.4 * units.GHz}, always_query_entire_description=True)

det.update(datetime.datetime(2022, 8, 2, 0, 0))
print(det.get_absolute_position(11))
print(det.get_relative_position(11, 0))
print(det.get_channel_orientation(11, 0))
# det.get_channel_signal_chain(11, 0)
res = det.get_signal_chain_response(11, 0)
print(res)

# det._query_station_information(11)

# det.update(datetime.datetime(2022, 8, 2, 0, 0))

# print(det.get_station_ids())
# print(det.get_full_station_information(11))
# det.export("test_det.pickle")

# det.get_full_station_from_buffer(11)
# det.update(datetime.datetime(2022, 8, 2, 0, 0))

# s = det.get_channel_signal_chain(11, 11)
# print(s.keys())

# time_filter = [
# {"$match": {
# 'id': 11,
# 'channels.id': 11,
# 'commission_time': {"$lte": datetime.datetime(2022, 8, 2, 0, 0)},
# 'decommission_time': {"$gte": datetime.datetime(2022, 8, 2, 0, 0)}}}]

# # get all stations which fit the filter
# channel_information = list(det.db().db["station_rnog"].aggregate(time_filter))
# print(len(channel_information[0]["channels"]))

0 comments on commit 87eae32

Please sign in to comment.