Pulse: fix save hilbert timing + figsizing

This commit is contained in:
Eric Teunis de Boone 2023-05-31 17:10:58 +02:00
parent 95594c59ab
commit e9f459b8fa

View file

@ -199,9 +199,9 @@ def read_time_residuals_cache(cache_fname, template_dt, antenna_dt, snr_sigma_fa
if len(ret.shape) > 2: if len(ret.shape) > 2:
return ret[0,:], ret[1,:], ret[2,:] return ret[0,:], ret[1,:], ret[2,:]
elif len(ret.shape) > 1: elif len(ret.shape) > 1:
return ret[0,:], ret[1,:], np.array([np.nan]*len(ret)) return ret[0,:], ret[1,:], np.array([np.nan]*len(ret[0]))
else: else:
return ret[:], np.array([np.nan]*len(ret)), np.array([np.nan]*len(ret)) return ret[:], np.array([np.nan]*len(ret[0])), np.array([np.nan]*len(ret[0]))
except (KeyError, FileNotFoundError): except (KeyError, FileNotFoundError):
return np.array([]), np.array([]), np.array([]) return np.array([]), np.array([]), np.array([])
@ -215,7 +215,8 @@ def write_time_residuals_cache(cache_fname, data, template_dt, antenna_dt, noise
if ds_name in pgroup2.keys(): if ds_name in pgroup2.keys():
del pgroup2[ds_name] del pgroup2[ds_name]
ds = pgroup2.create_dataset(ds_name, (3, len(time_residuals)), dtype='f', maxshape=(None))
ds = pgroup2.create_dataset(ds_name, (3, len(data[0])), dtype='f', maxshape=(None))
ds[0] = data[0] ds[0] = data[0]
ds[1] = data[1] ds[1] = data[1]
ds[2] = data[2] ds[2] = data[2]
@ -472,7 +473,7 @@ def get_time_residuals_for_template(
# Were new time residuals calculated? # Were new time residuals calculated?
# Add them to the cache file # Add them to the cache file
if len(time_residuals) > 1: if len(time_residuals) >= 1:
# merge cached and calculated time residuals # merge cached and calculated time residuals
time_residuals = np.concatenate((cached_time_residuals, time_residuals), axis=None) time_residuals = np.concatenate((cached_time_residuals, time_residuals), axis=None)
snrs = np.concatenate( (cached_snrs, snrs), axis=None) snrs = np.concatenate( (cached_snrs, snrs), axis=None)
@ -495,6 +496,12 @@ if __name__ == "__main__":
if os.name == 'posix' and "DISPLAY" not in os.environ: if os.name == 'posix' and "DISPLAY" not in os.environ:
matplotlib.use('Agg') matplotlib.use('Agg')
if False:
plt.rc('font', size=25)
figsize = (12,12)
bp_freq = (30e-3, 80e-3) # GHz bp_freq = (30e-3, 80e-3) # GHz
interp_template_dt = 5e-5 # ns interp_template_dt = 5e-5 # ns
template_length = 200 # ns template_length = 200 # ns
@ -521,6 +528,9 @@ if __name__ == "__main__":
use_cache = True use_cache = True
write_cache = None # Leave None for default action write_cache = None # Leave None for default action
wrong_peak_condition_multiple = 2
wrong_peak_condition = lambda t_res: abs(t_res) > antenna_dt*wrong_peak_condition_multiple
# #
# Interpolation Template # Interpolation Template
# to create an 'analog' sampled antenna # to create an 'analog' sampled antenna
@ -580,7 +590,6 @@ if __name__ == "__main__":
print()# separating tqdm print()# separating tqdm
print()# separating tqdm print()# separating tqdm
wrong_peak_condition = lambda t_res: abs(t_res) > antenna_dt*4
mask = wrong_peak_condition(time_residuals) mask = wrong_peak_condition(time_residuals)
# Save directly to large data array # Save directly to large data array
@ -677,10 +686,11 @@ if __name__ == "__main__":
# SNR time accuracy plot # SNR time accuracy plot
# #
if True: if True:
enable_threshold_markers = [False, False, True, True]
threshold_markers = ['^', 'v', '8', 'X'] # make sure to have filled markers here threshold_markers = ['^', 'v', '8', 'X'] # make sure to have filled markers here
mask_thresholds = np.array([np.inf, N_residuals*0.5, N_residuals*0.1, 1, 0]) mask_thresholds = np.array([np.inf, N_residuals*0.5, N_residuals*0.1, 1, 0])
fig, ax = plt.subplots() fig, ax = plt.subplots(figsize=figsize)
ax.set_title(f"Template matching SNR vs time accuracy") ax.set_title(f"Template matching SNR vs time accuracy")
ax.set_xlabel("Signal to Noise Factor") ax.set_xlabel("Signal to Noise Factor")
ax.set_ylabel("Time Accuracy [ns]") ax.set_ylabel("Time Accuracy [ns]")
@ -706,13 +716,14 @@ if __name__ == "__main__":
# calculate absolute deviation from the mean # calculate absolute deviation from the mean
residual_mean_deviation = np.sqrt( (time_residuals - mean_residual)**2 ) residual_mean_deviation = np.sqrt( (time_residuals - mean_residual)**2 )
snr_std = np.std(snrs) snr_std = np.std(snrs[valid_mask])
time_accuracy_std = np.std(residual_mean_deviation) time_accuracy_std = np.std(residual_mean_deviation[valid_mask])
scatter_kwargs = dict( scatter_kwargs = dict(
ls='none', ls='none',
marker='.', marker='o',
alpha=0.3, alpha=0.2,
ms=1,
zorder=1.8, zorder=1.8,
) )
@ -780,6 +791,7 @@ if __name__ == "__main__":
# limit y-axis upper limit to 1e1 # limit y-axis upper limit to 1e1
if True: if True:
this_lim = 1e1 this_lim = 1e1
if ax.get_ylim()[1] >= this_lim:
ax.set_ylim([None, this_lim]) ax.set_ylim([None, this_lim])
# require y-axis lower limit to be at least 1e-1 # require y-axis lower limit to be at least 1e-1
@ -796,10 +808,11 @@ if __name__ == "__main__":
if low_ylims <= this_lim: if low_ylims <= this_lim:
ax.set_ylim([this_lim, None]) ax.set_ylim([this_lim, None])
if True: # require y-axis lower limit to be at least 1e-1 # require x-axis lower limit to be under 1e0
low_ylims = ax.get_ylim()[0] if True:
if low_ylims >= 1e-1: this_lim = 1e0
ax.set_ylim([1e-1, None]) if ax.get_xlim()[0] >= this_lim:
ax.set_xlim([this_lim, None])
fig.tight_layout() fig.tight_layout()
if len(template_dts) == 1: if len(template_dts) == 1: