Pulse: internal: calculate and cache snr values

This commit is contained in:
Eric Teunis de Boone 2023-05-17 18:23:10 +02:00
parent eac9285efe
commit cc4b545260

View file

@ -156,13 +156,18 @@ def read_time_residuals_cache(cache_fname, template_dt, antenna_dt, snr_sigma_fa
ds = pgroup2[ds_name]
if N is None:
return deepcopy(ds[:])
ret = deepcopy(ds[:])
else:
return deepcopy(ds[:min(N, len(ds))])
except (KeyError, FileNotFoundError):
return np.array([])
ret = deepcopy(ds[:min(N, len(ds))])
def write_time_residuals_cache(cache_fname, time_residuals, template_dt, antenna_dt, noise_sigma_factor):
if len(ret.shape) > 1:
return ret[0,:], ret[1,:]
else:
return ret[:], np.array([np.nan]*len(ret))
except (KeyError, FileNotFoundError):
return np.array([]), np.array([])
def write_time_residuals_cache(cache_fname, time_residuals, snrs, template_dt, antenna_dt, noise_sigma_factor):
with h5py.File(cache_fname, 'a') as fp:
pgroup = fp.require_group('time_residuals')
pgroup2 = pgroup.require_group(f'{template_dt}_{antenna_dt}')
@ -172,7 +177,9 @@ def write_time_residuals_cache(cache_fname, time_residuals, template_dt, antenna
if ds_name in pgroup2.keys():
del pgroup2[ds_name]
ds = pgroup2.create_dataset(ds_name, (len(time_residuals)), dtype='f', data=time_residuals, maxshape=(None))
ds = pgroup2.create_dataset(ds_name, (2, len(time_residuals)), dtype='f', maxshape=(None))
ds[0] = time_residuals
ds[1] = snrs
def create_template(dt=1, timelength=1, bp_freq=(0, np.inf), name=None, normalise=False):
template = Waveform(None, dt=dt, name=name)
@ -196,15 +203,16 @@ def get_time_residuals_for_template(
):
# Read in cached time residuals
if read_cache:
cached_time_residuals = read_time_residuals_cache(h5_cache_fname, template.dt, antenna_dt, snr_sigma_factor)
cached_time_residuals, cached_snrs = read_time_residuals_cache(h5_cache_fname, template.dt, antenna_dt, snr_sigma_factor)
else:
cached_time_residuals = np.array([])
cached_time_residuals, cached_snrs = np.array([]), np.array([])
#
# Find difference between true and templated times
#
time_residuals = np.zeros(max(0, (N_residuals - len(cached_time_residuals))))
snrs = np.zeros_like(time_residuals)
for j in tqdm(range(len(time_residuals))):
do_plots = j==0
@ -340,6 +348,7 @@ def get_time_residuals_for_template(
# Calculate the time residual
time_residuals[j] = best_time_lag - true_time_offset
snrs[j] = antenna.signal_to_noise
if not do_plots:
continue
@ -419,14 +428,16 @@ def get_time_residuals_for_template(
if len(time_residuals) > 1:
# merge cached and calculated time residuals
time_residuals = np.concatenate((cached_time_residuals, time_residuals), axis=None)
snrs = np.concatenate( (cached_snrs, snrs), axis=None)
if write_cache or read_cache and write_cache is None: # write the cache
write_time_residuals_cache(h5_cache_fname, time_residuals, template_dt, antenna_dt, snr_sigma_factor)
write_time_residuals_cache(h5_cache_fname, time_residuals, snrs, template_dt, antenna_dt, snr_sigma_factor)
else:
time_residuals = cached_time_residuals
snrs = cached_snrs
# Only return N_residuals (even if more have been cached)
return time_residuals[:N_residuals]
return time_residuals[:N_residuals], snrs[:N_residuals]
if __name__ == "__main__":
import os
@ -506,7 +517,7 @@ if __name__ == "__main__":
for k, snr_sigma_factor in tqdm(enumerate(snr_factors)):
# get the time residuals
time_residuals = get_time_residuals_for_template(
time_residuals, snrs = get_time_residuals_for_template(
N_residuals, template, interpolation_template=interp_template,
antenna_dt=antenna_dt, antenna_timelength=antenna_timelength,
snr_sigma_factor=snr_sigma_factor, bp_freq=bp_freq, normalise_noise=normalise_noise,