Pulse: move timeresidual matching to function

This commit is contained in:
Eric Teunis de Boone 2023-04-26 15:45:42 +02:00
parent 1f00a3fe76
commit 168b0a60bc
1 changed files with 227 additions and 211 deletions

View File

@ -172,6 +172,227 @@ def create_template(dt=1, timelength=1, bp_freq=(0, np.inf), name=None, normalis
return template, _deltapeak
def get_time_residuals_for_template(
N_residuals, template, interpolation_template=None,
antenna_dt=1, antenna_timelength=100,
snr_sigma_factor=10,bp_freq=(0,np.inf),
normalise_noise=False, h5_cache_fname=None, read_cache=True, write_cache=None,
rng=rng, tqdm=tqdm,
):
# Read in cached time residuals
if read_cache:
cached_time_residuals = read_time_residuals_cache(h5_cache_fname, template.dt, antenna_dt, snr_sigma_factor)
else:
cached_time_residuals = np.array([])
#
# Find difference between true and templated times
#
time_residuals = np.zeros(max(0, (N_residuals - len(cached_time_residuals))))
for j in tqdm(range(len(time_residuals))):
do_plots = j==0
# receive at antenna
## place the deltapeak signal at a random location
antenna = Waveform(None, dt=antenna_dt, name='Signal')
if interpolation_template is None: # Create antenna trace without interpolation template
antenna_true_signal, antenna_peak_sample = util.deltapeak(timelength=antenna_timelength, samplerate=1/antenna.dt, offset=[0.2, 0.8], rng=rng)
antenna.peak_sample = antenna_peak_sample
antenna.peak_time = antenna.dt * antenna.peak_sample
antenna.signal = antenna_bp(antenna.signal, *bp_freq, antenna.dt)
print(f"Antenna Peak Time: {antenna.peak_time}")
print(f"Antenna Peak Sample: {antenna.peak_sample}")
else: # Sample the interpolation template at some offset
antenna.peak_time = antenna_timelength * ((0.8 - 0.2) *rng.random(1) + 0.2)
sampling_offset = rng.random(1)*antenna.dt
antenna.t = util.sampled_time(1/antenna.dt, start=0, end=antenna_timelength)
# Sample the interpolation template
antenna.signal = interpolation_template.interpolate(antenna.t - antenna.peak_time)
antenna.peak_sample = antenna.peak_time/antenna.dt
antenna_true_signal = antenna.signal
true_time_offset = antenna.peak_time - template.peak_time
if False: # flip polarisation
antenna.signal *= -1
## Add noise
noise_amplitude = max(template.signal) * 1/snr_sigma_factor
noise_realisation = noise_amplitude * white_noise_realisation(len(antenna.signal), normalise=normalise_noise)
filtered_noise = antenna_bp(noise_realisation, *bp_freq, antenna.dt)
antenna.signal += filtered_noise
# Show signals
if do_plots:
fig, axs = plt.subplots(2, sharex=True)
axs[0].set_title("Antenna Waveform")
axs[-1].set_xlabel("Time [ns]")
axs[0].set_ylabel("Amplitude")
axs[0].plot(antenna.t, antenna.signal, label='bandpassed w/ noise', alpha=0.9)
axs[0].plot(antenna.t, antenna.signal - filtered_noise, label='bandpassed w/o noise', alpha=0.9)
axs[0].legend()
axs[1].set_title("Template")
axs[1].set_ylabel("Amplitude")
axs[1].plot(template.t, template.signal, label='orig')
axs[1].plot(template.t + true_time_offset, template.signal, label='true moved orig')
axs[1].legend()
axs[0].grid()
axs[1].grid()
fig.savefig('figures/11_antenna_signals.pdf')
if True: # zoom
wx = 100
x0 = true_time_offset
old_xlims = axs[0].get_xlim()
axs[0].set_xlim( x0-wx, x0+wx)
fig.savefig('figures/11_antenna_signals_zoom.pdf')
# restore
axs[0].set_xlim(*old_xlims)
if True:
plt.close(fig)
axs2 = None
if True: # upsampled trace
upsampled_trace, upsampled_t = trace_upsampler(antenna.signal, template.t, antenna.t)
if do_plots: # Show upsampled traces
fig2, axs2 = plt.subplots(1, sharex=True)
if not hasattr(axs2, '__len__'):
axs2 = [axs2]
axs2[-1].set_xlabel("Time [ns]")
axs2[0].set_ylabel("Amplitude")
axs2[0].plot(antenna.t, antenna.signal, marker='o', label='orig')
axs2[0].plot(upsampled_t, upsampled_trace, label='upsampled')
axs2[0].legend(loc='upper right')
fig2.savefig('figures/11_upsampled.pdf')
wx = 1e2
x0 = upsampled_t[0] + wx - 5
axs2[0].set_xlim(x0-wx, x0+wx)
fig2.savefig('figures/11_upsampled_zoom.pdf')
if True:
plt.close(fig2)
# determine correlations with arguments
lag_dt = upsampled_t[1] - upsampled_t[0]
corrs, (out1_signal, out2_template, lags) = my_correlation(upsampled_trace, template.signal)
# Determine best correlation time
idx = np.argmax(abs(corrs))
best_sample_lag = lags[idx]
best_time_lag = best_sample_lag * lag_dt
else: # downsampled template
raise NotImplementedError
corrs, (_, _, lags) = my_downsampling_correlation(antenna.signal, antenna.t, template.signal, template.t)
lag_dt = upsampled_t[1] - upsampled_t[0]
# Calculate the time residual
time_residuals[j] = best_time_lag - true_time_offset
if not do_plots:
continue
if do_plots and axs2:
axs2[-1].axvline(best_time_lag, color='r', alpha=0.5, linewidth=2)
axs2[-1].axvline(true_time_offset, color='g', alpha=0.5, linewidth=2)
# Show the final signals correlated
if do_plots:
# amplitude scaling required for single axis plotting
template_amp_scaler = max(abs(template.signal)) / max(abs(antenna.signal))
# start the figure
fig, axs = plt.subplots(2, sharex=True)
ylabel_kwargs = dict(
#rotation=0,
ha='right',
va='center'
)
axs[-1].set_xlabel("Time [ns]")
offset_list = [
[best_time_lag, dict(label=template.name, color='orange')],
[true_time_offset, dict(label='True offset', color='green')],
]
# Signal
i=0
axs[i].set_ylabel("Amplitude", **ylabel_kwargs)
axs[i].plot(antenna.t, antenna.signal, label=antenna.name)
# Plot the template
for offset_args in offset_list:
this_kwargs = offset_args[1]
offset = offset_args[0]
l = axs[i].plot(offset + template.t, template_amp_scaler * template.signal, **this_kwargs)
axs[i].legend()
# Correlation
i=1
axs[i].set_ylabel("Correlation", **ylabel_kwargs)
axs[i].plot(lags * lag_dt, corrs)
# Lines across both axes
for offset_args in offset_list:
this_kwargs = offset_args[1]
offset = offset_args[0]
for i in [0,1]:
axs[i].axvline(offset, ls='--', color=this_kwargs['color'], alpha=0.7)
axs[0].axvline(offset + len(template.signal) * (template.t[1] - template.t[0]), color=this_kwargs['color'], alpha=0.7)
if True: # zoom
wx = len(template.signal) * (template.dt)/2
t0 = best_time_lag
old_xlims = axs[0].get_xlim()
axs[i].set_xlim( x0-wx, x0+3*wx)
fig.savefig('figures/11_corrs_zoom.pdf')
# restore
axs[i].set_xlim(*old_xlims)
fig.tight_layout()
fig.savefig('figures/11_corrs.pdf')
if True:
plt.close(fig)
# Were new time residuals calculated?
# Add them to the cache file
if len(time_residuals) > 1:
# merge cached and calculated time residuals
time_residuals = np.concatenate((cached_time_residuals, time_residuals), axis=None)
if write_cache or read_cache and write_cache is None: # write the cache
write_time_residuals_cache(h5_cache_fname, time_residuals, template_dt, antenna_dt, snr_sigma_factor)
else:
time_residuals = cached_time_residuals
# Only return N_residuals (even if more have been cached)
return time_residuals[:N_residuals]
if __name__ == "__main__":
import os
import matplotlib
@ -192,7 +413,7 @@ if __name__ == "__main__":
[10, 20, 30, 50],
[100, 200, 300, 500]
),
axis=None)
axis=None, dtype=float)
antenna_dt = 2 # ns
antenna_timelength = 1024 # ns
@ -245,223 +466,18 @@ if __name__ == "__main__":
time_accuracies = np.zeros(len(snr_factors))
mask_counts = np.zeros(len(snr_factors))
for k, snr_sigma_factor in tqdm(enumerate(snr_factors)):
# Read in cached time residuals
if True:
cached_time_residuals = read_time_residuals_cache(h5_cache_fname, template.dt, antenna_dt, snr_sigma_factor)
else:
cached_time_residuals = np.array([])
#
# Find difference between true and templated times
#
time_residuals = np.zeros(max(0, (N_residuals - len(cached_time_residuals))))
for j in tqdm(range(len(time_residuals))):
do_plots = j==0
# receive at antenna
## place the deltapeak signal at a random location
antenna = Waveform(None, dt=antenna_dt, name='Signal')
if False: # Create antenna trace without interpolation template
antenna_true_signal, antenna_peak_sample = util.deltapeak(timelength=antenna_timelength, samplerate=1/antenna.dt, offset=[0.2, 0.8], rng=rng)
antenna.peak_sample = antenna_peak_sample
antenna.peak_time = antenna.dt * antenna.peak_sample
antenna.signal = antenna_bp(antenna.signal, *bp_freq, antenna.dt)
print(f"Antenna Peak Time: {antenna.peak_time}")
print(f"Antenna Peak Sample: {antenna.peak_sample}")
else: # Sample the interpolation template at some offset
antenna.peak_time = antenna_timelength * ((0.8 - 0.2) *rng.random(1) + 0.2)
sampling_offset = rng.random(1)*antenna.dt
antenna.t = util.sampled_time(1/antenna.dt, start=0, end=antenna_timelength)
# Sample the interpolation template
antenna.signal = interp_template.interpolate(antenna.t - antenna.peak_time)
antenna.peak_sample = antenna.peak_time/antenna.dt
antenna_true_signal = antenna.signal
true_time_offset = antenna.peak_time - template.peak_time
if False: # flip polarisation
antenna.signal *= -1
## Add noise
noise_amplitude = max(template.signal) * 1/snr_sigma_factor
noise_realisation = noise_amplitude * white_noise_realisation(len(antenna.signal), normalise=normalise_noise)
filtered_noise = antenna_bp(noise_realisation, *bp_freq, antenna.dt)
antenna.signal += filtered_noise
if do_plots: # show signals
fig, axs = plt.subplots(2, sharex=True)
axs[0].set_title("Antenna Waveform")
axs[-1].set_xlabel("Time [ns]")
axs[0].set_ylabel("Amplitude")
axs[0].plot(antenna.t, antenna.signal, label='bandpassed w/ noise', alpha=0.9)
axs[0].plot(antenna.t, antenna.signal - filtered_noise, label='bandpassed w/o noise', alpha=0.9)
axs[0].legend()
axs[1].set_title("Template")
axs[1].set_ylabel("Amplitude")
axs[1].plot(template.t, template.signal, label='orig')
axs[1].plot(template.t + true_time_offset, template.signal, label='true moved orig')
axs[1].legend()
axs[0].grid()
axs[1].grid()
fig.savefig('figures/11_antenna_signals.pdf')
if True: # zoom
wx = 100
x0 = true_time_offset
old_xlims = axs[0].get_xlim()
axs[0].set_xlim( x0-wx, x0+wx)
fig.savefig('figures/11_antenna_signals_zoom.pdf')
# restore
axs[0].set_xlim(*old_xlims)
if True:
plt.close(fig)
axs2 = None
if True: # upsampled trace
upsampled_trace, upsampled_t = trace_upsampler(antenna.signal, template.t, antenna.t)
if do_plots: # Show upsampled traces
fig2, axs2 = plt.subplots(1, sharex=True)
if not hasattr(axs2, '__len__'):
axs2 = [axs2]
axs2[-1].set_xlabel("Time [ns]")
axs2[0].set_ylabel("Amplitude")
axs2[0].plot(antenna.t, antenna.signal, marker='o', label='orig')
axs2[0].plot(upsampled_t, upsampled_trace, label='upsampled')
axs2[0].legend(loc='upper right')
fig2.savefig('figures/11_upsampled.pdf')
wx = 1e2
x0 = upsampled_t[0] + wx - 5
axs2[0].set_xlim(x0-wx, x0+wx)
fig2.savefig('figures/11_upsampled_zoom.pdf')
if True:
plt.close(fig2)
# determine correlations with arguments
lag_dt = upsampled_t[1] - upsampled_t[0]
corrs, (out1_signal, out2_template, lags) = my_correlation(upsampled_trace, template.signal)
# Determine best correlation time
idx = np.argmax(abs(corrs))
best_sample_lag = lags[idx]
best_time_lag = best_sample_lag * lag_dt
else: # downsampled template
raise NotImplementedError
corrs, (_, _, lags) = my_downsampling_correlation(antenna.signal, antenna.t, template.signal, template.t)
lag_dt = upsampled_t[1] - upsampled_t[0]
# Calculate the time residual
time_residuals[j] = best_time_lag - true_time_offset
if not do_plots:
continue
if do_plots and axs2:
axs2[-1].axvline(best_time_lag, color='r', alpha=0.5, linewidth=2)
axs2[-1].axvline(true_time_offset, color='g', alpha=0.5, linewidth=2)
# Show the final signals correlated
if do_plots:
# amplitude scaling required for single axis plotting
template_amp_scaler = max(abs(template.signal)) / max(abs(antenna.signal))
# start the figure
fig, axs = plt.subplots(2, sharex=True)
ylabel_kwargs = dict(
#rotation=0,
ha='right',
va='center'
)
axs[-1].set_xlabel("Time [ns]")
offset_list = [
[best_time_lag, dict(label=template.name, color='orange')],
[true_time_offset, dict(label='True offset', color='green')],
]
# Signal
i=0
axs[i].set_ylabel("Amplitude", **ylabel_kwargs)
axs[i].plot(antenna.t, antenna.signal, label=antenna.name)
# Plot the template
for offset_args in offset_list:
this_kwargs = offset_args[1]
offset = offset_args[0]
l = axs[i].plot(offset + template.t, template_amp_scaler * template.signal, **this_kwargs)
axs[i].legend()
# Correlation
i=1
axs[i].set_ylabel("Correlation", **ylabel_kwargs)
axs[i].plot(lags * lag_dt, corrs)
# Lines across both axes
for offset_args in offset_list:
this_kwargs = offset_args[1]
offset = offset_args[0]
for i in [0,1]:
axs[i].axvline(offset, ls='--', color=this_kwargs['color'], alpha=0.7)
axs[0].axvline(offset + len(template.signal) * (template.t[1] - template.t[0]), color=this_kwargs['color'], alpha=0.7)
if True: # zoom
wx = len(template.signal) * (template.dt)/2
t0 = best_time_lag
old_xlims = axs[0].get_xlim()
axs[i].set_xlim( x0-wx, x0+3*wx)
fig.savefig('figures/11_corrs_zoom.pdf')
# restore
axs[i].set_xlim(*old_xlims)
fig.tight_layout()
fig.savefig('figures/11_corrs.pdf')
if True:
plt.close(fig)
time_residuals = get_time_residuals_for_template(
N_residuals, template, interpolation_template=interp_template,
antenna_dt=antenna_dt, antenna_timelength=antenna_timelength,
snr_sigma_factor=snr_sigma_factor, bp_freq=bp_freq, normalise_noise=normalise_noise,
h5_cache_fname=h5_cache_fname, rng=rng, tqdm=tqdm)
print()# separating tqdm
print()# separating tqdm
# Were new time residuals calculated?
# Add them to the cache file
if len(time_residuals) > 1:
# merge cached and calculated time residuals
time_residuals = np.concatenate((cached_time_residuals, time_residuals), axis=None)
if True: # write the cache
write_time_residuals_cache(h5_cache_fname, time_residuals, template_dt, antenna_dt, snr_sigma_factor)
else:
time_residuals = cached_time_residuals
# Make a plot of the time residuals
if N_residuals > 1:
time_residuals = time_residuals[:N_residuals]
for i in range(1 + cut_wrong_peak_matches):
mask_count = 0