Pulse: snr plot multiple template_dt curves

This commit is contained in:
Eric Teunis de Boone 2023-04-26 17:04:34 +02:00
parent 168b0a60bc
commit 59feab014e

View file

@ -401,11 +401,14 @@ if __name__ == "__main__":
matplotlib.use('Agg') matplotlib.use('Agg')
bp_freq = (30e-3, 80e-3) # GHz bp_freq = (30e-3, 80e-3) # GHz
template_dt = 5e-2 # ns
interp_template_dt = 5e-5 # ns interp_template_dt = 5e-5 # ns
template_length = 200 # ns template_length = 200 # ns
antenna_dt = 2 # ns
antenna_timelength = 1024 # ns
N_residuals = 50*3 if len(sys.argv) < 2 else int(sys.argv[1]) N_residuals = 50*3 if len(sys.argv) < 2 else int(sys.argv[1])
template_dts = np.array([antenna_dt, 5e-1, 5e-2]) # ns
snr_factors = np.concatenate( # 1/noise_amplitude factor snr_factors = np.concatenate( # 1/noise_amplitude factor
( (
#[0.25, 0.5, 0.75], #[0.25, 0.5, 0.75],
@ -415,8 +418,6 @@ if __name__ == "__main__":
), ),
axis=None, dtype=float) axis=None, dtype=float)
antenna_dt = 2 # ns
antenna_timelength = 1024 # ns
cut_wrong_peak_matches = True cut_wrong_peak_matches = True
normalise_noise = False normalise_noise = False
@ -454,19 +455,21 @@ if __name__ == "__main__":
if True: if True:
plt.close(fig) plt.close(fig)
#
# Create the template
# This is sampled at a lower samplerate than the interpolation template
#
template, _ = create_template(dt=template_dt, timelength=template_length, bp_freq=bp_freq, name='Template')
# #
# Find time accuracies as a function of signal strength # Find time accuracies as a function of signal strength
# #
time_accuracies = np.zeros(len(snr_factors)) time_accuracies = np.zeros((len(template_dts), len(snr_factors)))
mask_counts = np.zeros(len(snr_factors)) mask_counts = np.zeros_like(time_accuracies)
for l, template_dt in tqdm(enumerate(template_dts)):
# Create the template
# This is sampled at a lower samplerate than the interpolation template
template, _ = create_template(dt=template_dt, timelength=template_length, bp_freq=bp_freq, name='Template')
for k, snr_sigma_factor in tqdm(enumerate(snr_factors)): for k, snr_sigma_factor in tqdm(enumerate(snr_factors)):
# get the time residuals
time_residuals = get_time_residuals_for_template( time_residuals = get_time_residuals_for_template(
N_residuals, template, interpolation_template=interp_template, N_residuals, template, interpolation_template=interp_template,
antenna_dt=antenna_dt, antenna_timelength=antenna_timelength, antenna_dt=antenna_dt, antenna_timelength=antenna_timelength,
@ -492,11 +495,10 @@ if __name__ == "__main__":
time_residuals = time_residuals[~mask] time_residuals = time_residuals[~mask]
if not mask_count: if not mask_count:
print("Continuing")
continue continue
time_accuracies[k] = np.std(time_residuals) time_accuracies[l, k] = np.std(time_residuals)
mask_counts[k] = mask_count mask_counts[l, k] = mask_count
hist_kwargs = dict(bins='sqrt', density=False, alpha=0.8, histtype='step') hist_kwargs = dict(bins='sqrt', density=False, alpha=0.8, histtype='step')
fig, ax = plt.subplots() fig, ax = plt.subplots()
@ -567,10 +569,11 @@ if __name__ == "__main__":
ax.set_title(f"Template matching SNR vs time accuracy") ax.set_title(f"Template matching SNR vs time accuracy")
ax.set_xlabel("Signal to Noise Factor") ax.set_xlabel("Signal to Noise Factor")
ax.set_ylabel("Time Accuracy [ns]") ax.set_ylabel("Time Accuracy [ns]")
ax.grid()
ax.legend(title="\n".join([ ax.legend(title="\n".join([
f"N={N_residuals}", f"N={N_residuals}",
f"template_dt={template_dt:0.1e}ns", #f"template_dt={template_dt:0.1e}ns",
f"antenna_dt={antenna_dt:0.1e}ns", f"antenna_dt={antenna_dt:0.1e}ns",
])) ]))
@ -578,28 +581,36 @@ if __name__ == "__main__":
ax.set_xscale('log') ax.set_xscale('log')
ax.set_yscale('log') ax.set_yscale('log')
# plot the values # plot the values per template_dt slice
l = None template_dt_colors = [None]*len(template_dts)
for k, template_dt in enumerate(template_dts):
# indicate masking values
for j, mask_threshold in enumerate(pairwise([np.inf, 250, 50, 1, 0])): for j, mask_threshold in enumerate(pairwise([np.inf, 250, 50, 1, 0])):
kwargs = dict( kwargs = dict(
ls='none', ls='none',
marker=['^', 'v','8', 'o',][j], marker=['^', 'v','8', 'o',][j],
color=None if l is None else l[0].get_color(), color= None if template_dt_colors[k] is None else template_dt_colors[k]
) )
mask = mask_counts >= mask_threshold[1] mask = mask_counts[k] >= mask_threshold[1]
mask &= mask_counts < mask_threshold[0] mask &= mask_counts[k] < mask_threshold[0]
l = ax.plot(snr_factors[mask], time_accuracies[mask], **kwargs) l = ax.plot(snr_factors[mask], time_accuracies[k][mask], **kwargs)
template_dt_colors[k] = l[0].get_color()
if True: # limit y-axis to 1e1 # indicate threshold
ax.set_ylim([None, 1e1]) if True:
ax.axhline(template_dt/np.sqrt(12), ls='--', alpha=0.7, color=template_dt_colors[k], label=f'Template dt:{template_dt:0.1e}ns')
# Set horizontal line at 1 ns # Set horizontal line at 1 ns
if True: if True:
ax.axhline(1, ls='--', alpha=0.8, color='g') ax.axhline(1, ls='--', alpha=0.8, color='g')
ax.grid()
ax.axhline(template_dt/np.sqrt(12), ls='--', alpha=0.7, color='b') ax.legend()
if True: # limit y-axis to 1e1
ax.set_ylim([None, 1e1])
fig.tight_layout() fig.tight_layout()
fig.savefig(f"figures/11_time_res_vs_snr_tdt{template_dt:0.1e}.pdf") fig.savefig(f"figures/11_time_res_vs_snr_tdt{template_dt:0.1e}.pdf")