2022-03-10 14:06:17 +01:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
"""
|
|
|
|
Define the TravelSignal class.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import scipy.interpolate as interp
|
|
|
|
|
|
|
|
class TravelSignal:
    """
    Model an arbitrary digitised signal that can be translated to another position and time.

    The raw samples are assumed to start at ``(t_0, x_0)``.  Calling the
    instance with a time and/or position returns the amplitude the signal
    has there, obtained by converting the total time offset into a
    (fractional) sample index and interpolating between samples.
    """

    def __init__(self, signal, sample_rate, t_0 = 0, x_0 = 0, periodic=True, interp1d_kw = None):
        """
        Initialise by saving the raw signal

        Parameters
        ----------
        signal : arraylike
            The raw signal to wrap. It is stored as a numpy array so that
            lists and tuples can be indexed with integer arrays later on.
        sample_rate : float
            Sample rate of the raw signal (Hz).
        t_0 : float, optional
            Time that this signal is sent out.
        x_0 : float, optional
            Location that this signal is sent out from.
        periodic : bool, optional
            If True the signal repeats itself outside the sampled range.
            If False the translated signal is 0 when the time/distance
            falls outside the samples.
        interp1d_kw : bool or dict, optional
            Falsy (default): use the built-in linear interpolation.
            True: use ``scipy.interpolate.interp1d`` with the defaults below.
            dict: same, with the dictionary merged over the defaults and
            passed to ``interp1d`` as ``**kwargs``.
        """
        # np.asarray does not copy when `signal` already is an ndarray
        self.raw = np.asarray(signal)
        self.periodic = periodic

        self.sample_rate = sample_rate # Hz
        self.sample_length = len(self.raw)
        # duration = number of samples / samples per second
        self.time_length = self.sample_length / sample_rate # s

        self.x_0 = x_0
        self.t_0 = t_0

        # built-in linear interpolation
        if not interp1d_kw:
            self.interp_f = None
            return

        # offload interpolation to scipy.interpolate
        merged_kw = {
            "copy": False,
            "kind": 'linear',
            "assume_sorted": True,
            "bounds_error": True,
        }

        if self.periodic:
            # After wrapping, an index can still lie between the last sample
            # and sample_length, which is above interp1d's x range; use the
            # edge samples there instead of raising.
            merged_kw['bounds_error'] = False
            merged_kw['fill_value'] = (self.raw[-1], self.raw[0])

        # `True` means "defaults only"; anything else is merged over them
        if interp1d_kw is not True:
            merged_kw.update(interp1d_kw)

        self.interp_f = interp.interp1d(
            np.arange(0, self.sample_length),
            self.raw,
            **merged_kw
        )

    def __len__(self):
        """Return the number of raw samples."""
        return self.sample_length

    def __call__(self, t_f = None, x_f = None, **kwargs):
        """
        Allow this class to be used as a function.

        Returns only the amplitude at (t_f, x_f); extra keyword arguments
        (``t_0``, ``x_0``, ``velocity``) are forwarded to ``_translate``.
        """
        return self._translate(t_f, x_f, **kwargs)[0]

    def _translate(self, t_f = None, x_f = None, t_0 = None, x_0 = None, velocity = None):
        """
        Translate the signal from (t_0, x_0) to (t_f, x_f) with optional velocity.

        Parameters
        ----------
        t_f : float or arraylike, optional
            Time(s) at which to evaluate the signal. No temporal offset if None.
        x_f : float or arraylike, optional
            Position at which to evaluate the signal. No spatial offset if None.
        t_0 : float, optional
            Emission time, defaults to ``self.t_0``.
        x_0 : float, optional
            Emission position, defaults to ``self.x_0``.
        velocity : float, optional
            Propagation velocity, defaults to 1.

        Returns
        -------
        amplitude
            The signal at (t_f, x_f); 0 where a non-periodic signal is
            evaluated outside its samples.
        total_time_offset
            Temporal offset minus the spatial travel time.
        """
        if t_0 is None:
            t_0 = self.t_0
        if x_0 is None:
            x_0 = self.x_0
        if velocity is None:
            velocity = 1

        ## spatial offset
        if x_f is None:
            spatial_time_offset = 0
        else:
            # NOTE(review): this sums |x_f - x_0| over *all* elements of x_f,
            # so it always yields one scalar delay (a Manhattan distance for
            # a multi-component position) -- confirm this is intended.
            separation = np.sqrt((np.asarray(x_f) - x_0)**2)
            spatial_time_offset = np.sum(separation/velocity)

        ## temporal offset
        if t_f is None:
            temporal_time_offset = 0
        else:
            temporal_time_offset = np.asarray(t_f) - t_0

        # total offset
        total_time_offset = temporal_time_offset - spatial_time_offset

        # fractional sample index; np.array makes a private float copy so the
        # caller's data (and the returned total_time_offset) is never modified
        n_offset = np.array(total_time_offset * self.sample_rate, dtype=float)

        if self.periodic:
            outside = None
            n_offset = n_offset % self.sample_length
        else:
            # Negating "inside" also classifies NaN as outside.
            inside = (n_offset >= 0) & (n_offset < self.sample_length)
            outside = ~inside
            # park out-of-range entries on a valid index; they are zeroed below
            n_offset = np.where(outside, 0.0, n_offset)

        if self.interp_f is not None:
            # offload to scipy interpolation
            amplitude = self.interp_f(n_offset)
        else:
            # self written linear interpolation between neighbouring samples
            frac, whole = np.modf(n_offset)
            # The extra modulo guards against `x % N` rounding up to exactly N
            # for tiny negative x.
            idx = whole.astype(int) % self.sample_length
            idx_next = (idx + 1) % self.sample_length
            amplitude = (1 - frac) * self.raw[idx] + frac * self.raw[idx_next]

        if outside is not None:
            amplitude = np.where(outside, 0, amplitude)

        # `[()]` unwraps a 0-d array into a scalar and leaves n-d arrays as is
        return np.asarray(amplitude)[()], total_time_offset
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: plot a Gaussian pulse next to its periodic and non-periodic
    # TravelSignal reconstructions, sampled on a wider time axis.
    import matplotlib.pyplot as plt
    from scipy.stats import norm

    sample_rate = 3e2 # Hz

    t_offset = 8

    # `time` is the axis the raw signal is defined on; `time2` starts 1.5 s
    # earlier, so part of it lies before the first raw sample
    time = t_offset + np.arange(0, 1, 1/sample_rate) #s
    time2 = t_offset + np.arange(-1.5, 1, 1/sample_rate) #s

    # Gaussian centred on the middle sample, width a tenth of the time span
    signal = norm.pdf(time, time[len(time)//2], (time[-1] - time[0])/10)

    mysignal = TravelSignal(signal, sample_rate, t_0 = t_offset, periodic=True)
    mysignal2 = TravelSignal(signal, sample_rate, t_0 = t_offset, periodic=False)

    fig, ax = plt.subplots(1, 1, figsize=(16,4))
    ax.set_title("Raw and TravelSignal")
    ax.set_ylabel("Amplitude")
    ax.set_xlabel("Time")

    # the translated signals are shifted vertically so the curves do not overlap
    ax.plot(time, signal, label='Raw signal')
    ax.plot(time2, mysignal(time2)+0.5, '.-', label='TravelSignal(periodic)+0.5')
    ax.plot(time2, mysignal2(time2)-0.5, '.-', label='TravelSignal-0.5')

    ax.legend()

    plt.show()
|