Faster template correlation for pulsed_timing

This commit is contained in:
Eric Teunis de Boone 2023-04-19 20:35:53 +02:00
parent b415786806
commit bca924cdc2

View file

@ -65,24 +65,46 @@ def antenna_bp(trace, low_bp, high_bp, dt, order=3):
return bandpassed
def my_correlation(in1, template):
#
in1_long = np.zeros( (len(in1)+2*len(template)) )
in1_long[len(template):-len(template)] = in1
def my_correlation(in1, template, lags=None):
template_length = len(template)
in1_length = len(in1)
# fill the template with zeros and copy template
template_long = np.zeros_like(in1_long)
template_long[len(template):2*len(template)] = template
lags = np.arange(-len(template), len(in1) ) - len(template)
if lags is None:
lags = np.arange(-template_length+1, in1_length + 1)
# do the correlation jig
corrs = np.zeros_like(lags, dtype=float)
for i, l in enumerate(lags):
lagged_template = np.roll(template_long, l)
corrs[i] = np.dot(lagged_template, in1_long)
if l <= 0: # shorten template at the front
in1_start = 0
template_end = template_length
return corrs, (in1_long, template_long, lags)
template_start = -template_length - l
in1_end = max(0, min(in1_length, -template_start)) # 0 =< l + template_length =< in1_lengt
elif l > in1_length - template_length:
# shorten template from the back
in1_end = in1_length
template_start = 0
in1_start = min(l, in1_length)
template_end = max(0, in1_length - l)
else:
in1_start = min(l, in1_length)
in1_end = min(in1_start + template_length, in1_length)
# full template
template_start = 0
template_end = template_length
# Slice in1 and template
in1_slice = in1[in1_start:in1_end]
template_slice = template[template_start:template_end]
corrs[i] = np.dot(in1_slice, template_slice)
return corrs, (in1, template, lags)
def trace_upsampler(template_signal, trace, template_t, trace_t):
template_dt = template.t[1] - template.t[0]
@ -297,7 +319,7 @@ if __name__ == "__main__":
axs[i].axvline(offset, ls='--', **this_kwargs)
if True: # zoom
wx = len(template.signal) * (template.t[1] - template.t[0])/2
wx = len(template.signal) * (template.dt)/2
t0 = best_time_lag
old_xlims = axs[0].get_xlim()