Skip to content

Commit

Permalink
Merge branch 'gwastro:master' into multiband_transform
Browse files Browse the repository at this point in the history
  • Loading branch information
WuShichao authored Oct 30, 2023
2 parents e39902b + f85f8bf commit f8cb244
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
40 changes: 30 additions & 10 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy, h5py, logging, argparse, numpy.random
import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
from pycbc import transforms
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc.distributions import read_params_from_config
from pycbc.distributions.utils import draw_samples_from_config, prior_from_config
from scipy.stats import gaussian_kde
Expand All @@ -45,6 +46,8 @@ parser.add_argument('--approximant', required=True,
parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
help='size of waveform buffer in seconds')
parser.add_argument('--max-signal-length', type= float,
help="When specified, it cuts the maximum length of the waveform model to the lengh provided")
parser.add_argument('--sample-rate', default=2048, type=float,
help='sample rate in seconds')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float)
Expand Down Expand Up @@ -269,15 +272,28 @@ class GenUniformWaveform(object):
self.md = q._data[-100:]
self.md2 = q._data[0:100]

def generate(self, **kwds):
def generate(self, **kwds):
kwds.update(fdict)
if kwds['approximant'] in pycbc.waveform.fd_approximants():
ws = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_lower=self.f_lower, **kwds)
hp = ws[0]
hc = ws[1]
if args.max_signal_length is not None:
flow = numpy.arange(self.f_lower, 100, .1)[::-1]
length = spa_length_in_time(mass1=kwds['mass1'], mass2=kwds['mass2'], f_lower=flow, phase_order=-1)
maxlen = args.max_signal_length
x = numpy.searchsorted(length, maxlen) - 1
l = length[x]
f = flow[x]
else:
f = self.f_lower

kwds['f_lower'] = f

if kwds['approximant'] in pycbc.waveform.fd_approximants():
hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_ref=10.0, **kwds)


if 'fratio' in kwds:
hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])

else:
dt = 1.0 / args.sample_rate
hp = pycbc.waveform.get_waveform_filter(
Expand Down Expand Up @@ -342,10 +358,10 @@ def draw(rtype):
p = bank.keys()
p = [k for k in p if k not in fdict]
p.remove('approximant')
p.remove('f_lower')
if args.input_config is not None:
p = variable_args
bdata = numpy.array([bank.key(k)[-trail:] for k in p])

kde = gaussian_kde(bdata)
points = kde.resample(size=size)
params = {k: v for k, v in zip(p, points)}
Expand Down Expand Up @@ -422,9 +438,10 @@ def cdraw(rtype, ts, te):
return None

return p

tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl

go = True

region = 0
Expand All @@ -447,6 +464,7 @@ while tau0s < args.tau0_end:
if r > 10:
conv = uconv
kloop = 0

while ((kloop == 0) or (kconv / okconv) > .5) and len(bank) > 10:
r += 1
kloop += 1
Expand All @@ -455,9 +473,11 @@ while tau0s < args.tau0_end:
bank, kconv = bank.check_params(gen, params, args.minimal_match)
logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s",
region, kloop, r, len(bank), kconv, len(bank) - blen)


if uconv:
logging.info('Ratio of convergences: %2.3f' % (kconv / (uconv)))
logging.info('Progress: {:.0%} completed'.format(tau0s/args.tau0_end))
logging.info('Progress: {:.0%} completed'.format(tau0e/args.tau0_end))

if kloop == 1:
okconv = kconv
Expand All @@ -473,9 +493,9 @@ while tau0s < args.tau0_end:
tau0e += args.tau0_crawl / 2

o = h5py.File(args.output_file, 'w')

for k in bank.keys():
val = bank.key(k)
if val.dtype.char == 'U':
val = val.astype('bytes')
o[k] = val
o['f_lower'] = numpy.ones(len(val)) * args.low_frequency_cutoff
21 changes: 11 additions & 10 deletions bin/pycbc_inspiral
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def template_triggers(t_num):
out_vals['time_index'],
opt.gps_start_time, opt.sample_rate)
#print(idx, out_vals['time_index'])

out_vals_all.append(copy.deepcopy(out_vals))
#print(out_vals_all)
return out_vals_all, tparam
Expand Down Expand Up @@ -358,10 +358,11 @@ with ctx:
if opt.psdvar_segment is not None:
logging.info("Calculating PSD variation")
psd_var = pycbc.psd.calc_filt_psd_variation(gwstrain, opt.psdvar_segment,
opt.psdvar_short_segment, opt.psdvar_long_segment,
opt.psdvar_short_segment, opt.psdvar_long_segment,
opt.psdvar_psd_duration, opt.psdvar_psd_stride,
opt.psd_estimation, opt.psdvar_low_freq, opt.psdvar_high_freq)


if opt.enable_q_transform:
logging.info("Performing q-transform on analysis segments")
q_trans = qtransform.inspiral_qtransform_generator(segments)
Expand Down Expand Up @@ -453,16 +454,16 @@ with ctx:

tsetup = time.time() - tstart
tcheckpoint = time.time()

tanalyze = list(range(tnum_start, len(bank)))
n = opt.finalize_events_template_rate
n = 1 if n is None else n
tchunks = [tanalyze[i:i + n] for i in range(0, len(tanalyze), n)]
tchunks = [tanalyze[i:i + n] for i in range(0, len(tanalyze), n)]

mmap = map
if opt.multiprocessing_nprocesses:
mmap = Pool(opt.multiprocessing_nprocesses).map

for tchunk in tchunks:
data = list(mmap(template_triggers, tchunk))

Expand All @@ -472,11 +473,11 @@ with ctx:
event_mgr.new_template(tmplt=tparam)

for edata in out_vals_all:
event_mgr.add_template_events(names, [edata[n] for n in names])
event_mgr.add_template_events(names, [edata[n] for n in names])

event_mgr.cluster_template_events("time_index", "snr", cluster_window)
event_mgr.finalize_template_events()

if opt.finalize_events_template_rate is not None:
event_mgr.consolidate_events(opt, gwstrain=gwstrain)

Expand All @@ -488,7 +489,7 @@ with ctx:
if opt.checkpoint_exit_maxtime and \
(time.time() - tstart > opt.checkpoint_exit_maxtime):
event_mgr.save_state(max(tchunk), opt.output + '.checkpoint')
sys.exit(opt.checkpoint_exit_code)
sys.exit(opt.checkpoint_exit_code)

event_mgr.consolidate_events(opt, gwstrain=gwstrain)
event_mgr.finalize_events()
Expand Down
4 changes: 0 additions & 4 deletions pycbc/psd/variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pycbc.psd
from pycbc.types import TimeSeries
from pycbc.filter import resample_to_delta_t


def mean_square(data, delta_t, srate, short_stride, stride):
Expand Down Expand Up @@ -113,9 +112,6 @@ def calc_filt_psd_variation(strain, segment, short_segment, psd_long_segment,
# Convert start and end times immediately to floats
start_time = float(strain.start_time)
end_time = float(strain.end_time)

# Resample the data
strain = resample_to_delta_t(strain, 1.0 / 2048)
srate = int(strain.sample_rate)

# Fix the step for the PSD estimation and the time to remove at the
Expand Down

0 comments on commit f8cb244

Please sign in to comment.