diff --git a/simulations/11_pulsed_timing.py b/simulations/11_pulsed_timing.py index 9e346f3..7de08de 100755 --- a/simulations/11_pulsed_timing.py +++ b/simulations/11_pulsed_timing.py @@ -156,13 +156,18 @@ def read_time_residuals_cache(cache_fname, template_dt, antenna_dt, snr_sigma_fa ds = pgroup2[ds_name] if N is None: - return deepcopy(ds[:]) + ret = deepcopy(ds[:]) else: - return deepcopy(ds[:min(N, len(ds))]) - except (KeyError, FileNotFoundError): - return np.array([]) + ret = deepcopy(ds[:min(N, len(ds))]) -def write_time_residuals_cache(cache_fname, time_residuals, template_dt, antenna_dt, noise_sigma_factor): + if len(ret.shape) > 1: + return ret[0,:], ret[1,:] + else: + return ret[:], np.array([np.nan]*len(ret)) + except (KeyError, FileNotFoundError): + return np.array([]), np.array([]) + +def write_time_residuals_cache(cache_fname, time_residuals, snrs, template_dt, antenna_dt, noise_sigma_factor): with h5py.File(cache_fname, 'a') as fp: pgroup = fp.require_group('time_residuals') pgroup2 = pgroup.require_group(f'{template_dt}_{antenna_dt}') @@ -172,7 +177,9 @@ def write_time_residuals_cache(cache_fname, time_residuals, template_dt, antenna if ds_name in pgroup2.keys(): del pgroup2[ds_name] - ds = pgroup2.create_dataset(ds_name, (len(time_residuals)), dtype='f', data=time_residuals, maxshape=(None)) + ds = pgroup2.create_dataset(ds_name, (2, len(time_residuals)), dtype='f', maxshape=(None)) + ds[0] = time_residuals + ds[1] = snrs def create_template(dt=1, timelength=1, bp_freq=(0, np.inf), name=None, normalise=False): template = Waveform(None, dt=dt, name=name) @@ -196,15 +203,16 @@ def get_time_residuals_for_template( ): # Read in cached time residuals if read_cache: - cached_time_residuals = read_time_residuals_cache(h5_cache_fname, template.dt, antenna_dt, snr_sigma_factor) + cached_time_residuals, cached_snrs = read_time_residuals_cache(h5_cache_fname, template.dt, antenna_dt, snr_sigma_factor) else: - cached_time_residuals = np.array([]) + cached_time_residuals, cached_snrs = np.array([]), np.array([]) # # Find difference between true and templated times # time_residuals = np.zeros(max(0, (N_residuals - len(cached_time_residuals)))) + snrs = np.zeros_like(time_residuals) for j in tqdm(range(len(time_residuals))): do_plots = j==0 @@ -340,6 +348,7 @@ def get_time_residuals_for_template( # Calculate the time residual time_residuals[j] = best_time_lag - true_time_offset + snrs[j] = antenna.signal_to_noise if not do_plots: continue @@ -419,14 +428,16 @@ def get_time_residuals_for_template( if len(time_residuals) > 1: # merge cached and calculated time residuals time_residuals = np.concatenate((cached_time_residuals, time_residuals), axis=None) + snrs = np.concatenate( (cached_snrs, snrs), axis=None) if write_cache or read_cache and write_cache is None: # write the cache - write_time_residuals_cache(h5_cache_fname, time_residuals, template_dt, antenna_dt, snr_sigma_factor) + write_time_residuals_cache(h5_cache_fname, time_residuals, snrs, template_dt, antenna_dt, snr_sigma_factor) else: time_residuals = cached_time_residuals + snrs = cached_snrs # Only return N_residuals (even if more have been cached) - return time_residuals[:N_residuals] + return time_residuals[:N_residuals], snrs[:N_residuals] if __name__ == "__main__": import os @@ -506,7 +517,7 @@ if __name__ == "__main__": for k, snr_sigma_factor in tqdm(enumerate(snr_factors)): # get the time residuals - time_residuals = get_time_residuals_for_template( + time_residuals, snrs = get_time_residuals_for_template( N_residuals, template, interpolation_template=interp_template, antenna_dt=antenna_dt, antenna_timelength=antenna_timelength, snr_sigma_factor=snr_sigma_factor, bp_freq=bp_freq, normalise_noise=normalise_noise,