Skip to content

Commit

Permalink
Merge pull request #736 from nu-radio/basetrace-fs2
Browse files Browse the repository at this point in the history
Allow to get frequency spectrum from windowed trace.
  • Loading branch information
fschlueter authored Nov 18, 2024
2 parents 0a16f1f + e6ebb1d commit ec245da
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions NuRadioReco/framework/base_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,31 @@ def get_filtered_trace(self, passband, filter_type='butter', order=10, rp=None):
spec *= filter_response
return fft.freq2time(spec, self.get_sampling_rate())

def get_frequency_spectrum(self):
if self.__time_domain_up_to_date:
self._frequency_spectrum = fft.time2freq(self._time_trace, self._sampling_rate)
self._time_trace = None
# logger.debug("frequency spectrum has shape {}".format(self._frequency_spectrum.shape))
self.__time_domain_up_to_date = False
return np.copy(self._frequency_spectrum)
def get_frequency_spectrum(self, window_mask=None):
"""
Returns the frequency spectrum.
Parameters
----------
window_mask: array of bools (default: None)
If not None, specifies the time window to be used for the FFT. Has to have the same length as the trace.
Returns
-------
frequency_spectrum: np.array of floats
The frequency spectrum.
"""
if window_mask is None:
if self.__time_domain_up_to_date:
self._frequency_spectrum = fft.time2freq(self._time_trace, self._sampling_rate)
self._time_trace = None
self.__time_domain_up_to_date = False

return np.copy(self._frequency_spectrum)
else:
trace = copy.copy(self.get_trace())
# The double transpose allows to work with 1D and ND traces
return fft.time2freq(trace.T[window_mask].T, self._sampling_rate)

def set_trace(self, trace, sampling_rate):
"""
Expand Down Expand Up @@ -177,8 +195,26 @@ def add_trace_start_time(self, start_time):
def get_trace_start_time(self):
return self._trace_start_time

def get_frequencies(self):
return get_frequencies(self.get_number_of_samples(), self._sampling_rate)
def get_frequencies(self, window_mask=None):
"""
Returns the frequencies of the frequency spectrum.
Parameters
----------
window_mask: array of bools (default: None)
If not None, used to determine the number of samples in the time domain used for the frequency spectrum.
Returns
-------
frequencies: np.array of floats
The frequencies of the frequency spectrum.
"""
if window_mask is None:
nsamples = self.get_number_of_samples()
else:
nsamples = int(np.sum(window_mask))

return get_frequencies(nsamples, self._sampling_rate)

def get_hilbert_envelope(self):
from scipy import signal
Expand Down Expand Up @@ -449,4 +485,4 @@ def __truediv__(self, x):

@functools.lru_cache(maxsize=1024)
def get_frequencies(length, sampling_rate):
return np.fft.rfftfreq(length, d=1. / sampling_rate)
return np.fft.rfftfreq(length, d=1. / sampling_rate)

0 comments on commit ec245da

Please sign in to comment.