Pulse: move thresholding plot code around

This commit is contained in:
Eric Teunis de Boone 2023-04-28 16:11:08 +02:00
parent 22d96c99bd
commit 80cc240f23

View file

@ -265,7 +265,7 @@ def get_time_residuals_for_template(
plt.close(fig) plt.close(fig)
axs2 = None axs2 = None
if True: # upsampled trace if True: # simple and dumb trace upsampling
upsampled_trace, upsampled_t = trace_upsampler(antenna.signal, template.t, antenna.t) upsampled_trace, upsampled_t = trace_upsampler(antenna.signal, template.t, antenna.t)
if do_plots: # Show upsampled traces if do_plots: # Show upsampled traces
fig2, axs2 = plt.subplots(1, sharex=True) fig2, axs2 = plt.subplots(1, sharex=True)
@ -423,6 +423,7 @@ if __name__ == "__main__":
normalise_noise = False normalise_noise = False
h5_cache_fname = f'11_pulsed_timing.hdf5' h5_cache_fname = f'11_pulsed_timing.hdf5'
use_cache = True
# #
# Interpolation Template # Interpolation Template
@ -474,7 +475,7 @@ if __name__ == "__main__":
N_residuals, template, interpolation_template=interp_template, N_residuals, template, interpolation_template=interp_template,
antenna_dt=antenna_dt, antenna_timelength=antenna_timelength, antenna_dt=antenna_dt, antenna_timelength=antenna_timelength,
snr_sigma_factor=snr_sigma_factor, bp_freq=bp_freq, normalise_noise=normalise_noise, snr_sigma_factor=snr_sigma_factor, bp_freq=bp_freq, normalise_noise=normalise_noise,
h5_cache_fname=h5_cache_fname, rng=rng, tqdm=tqdm) h5_cache_fname=h5_cache_fname, rng=rng, tqdm=tqdm, read_cache=use_cache)
print()# separating tqdm print()# separating tqdm
print()# separating tqdm print()# separating tqdm
@ -565,6 +566,9 @@ if __name__ == "__main__":
# SNR time accuracy plot # SNR time accuracy plot
if True: if True:
threshold_markers = ['^', 'v', '8', 'o']
mask_thresholds = [np.inf, N_residuals*0.5, N_residuals*0.1, 1, 0]
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.set_title(f"Template matching SNR vs time accuracy") ax.set_title(f"Template matching SNR vs time accuracy")
ax.set_xlabel("Signal to Noise Factor") ax.set_xlabel("Signal to Noise Factor")
@ -586,10 +590,10 @@ if __name__ == "__main__":
for k, template_dt in enumerate(template_dts): for k, template_dt in enumerate(template_dts):
# indicate masking values # indicate masking values
for j, mask_threshold in enumerate(pairwise([np.inf, 250, 50, 1, 0])): for j, mask_threshold in enumerate(pairwise(mask_thresholds)):
kwargs = dict( kwargs = dict(
ls='none', ls='none',
marker=['^', 'v','8', 'o',][j], marker=threshold_markers[j],
color= None if template_dt_colors[k] is None else template_dt_colors[k] color= None if template_dt_colors[k] is None else template_dt_colors[k]
) )
mask = mask_counts[k] >= mask_threshold[1] mask = mask_counts[k] >= mask_threshold[1]
@ -613,6 +617,9 @@ if __name__ == "__main__":
ax.set_ylim([None, 1e1]) ax.set_ylim([None, 1e1])
fig.tight_layout() fig.tight_layout()
fig.savefig(f"figures/11_time_res_vs_snr_tdt{template_dt:0.1e}.pdf") if len(template_dts) == 1:
fig.savefig(f"figures/11_time_res_vs_snr_tdt{template_dt:0.1e}.pdf")
else:
fig.savefig(f"figures/11_time_res_vs_snr.pdf")
plt.show() plt.show()