2023-01-12 13:46:18 +01:00
|
|
|
import numpy as np
|
|
|
|
from collections import namedtuple
|
|
|
|
|
2023-02-02 08:49:05 +01:00
|
|
|
from lib import direct_fourier_transform as dtft
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt # for debug plotting
|
|
|
|
|
2023-01-12 13:46:18 +01:00
|
|
|
# A (low, high) frequency band, defaulting to 0 .. infinity.
# bandpower() treats a band whose `high` is None as a single frequency `low`.
passband = namedtuple("passband", ("low", "high"), defaults=(0, np.inf))
|
|
|
|
|
|
|
|
def get_freq_spec(val, dt):
    """Return the first half of the FFT of `val` together with its frequencies.

    Adapted from earsim/tools.py.

    Parameters
    ----------
    val : array_like
        Uniformly sampled values.
    dt : float
        Sample spacing, passed to ``np.fft.fftfreq``.

    Returns
    -------
    (spectrum, frequencies)
        The first ``len(val)//2`` entries of ``np.fft.fft(val)`` and of
        ``np.fft.fftfreq(len(val), dt)``.
    """
    n_points = len(val)
    half = n_points // 2

    spectrum = np.fft.fft(val)
    frequencies = np.fft.fftfreq(n_points, dt)

    return spectrum[:half], frequencies[:half]
|
|
|
|
|
|
|
|
|
|
|
|
def bandpass_samples(samples, samplerate, band=passband()):
    """
    Bandpass the samples with this passband.

    This is a hard (brick-wall) filter: every real-FFT bin whose frequency
    falls outside `band` (as decided by `bandpass_mask`) is set to zero and
    the spectrum is transformed back to the time domain.

    Parameters
    ----------
    samples : array_like
        Real-valued, uniformly sampled time series.  For N-d input the
        filter is applied along the last axis.
    samplerate : float
        Samples per unit time; the sample spacing is ``1/samplerate``
        (the same convention `bandpower` uses).
    band : passband
        ``(low, high)`` frequency limits in the units of `samplerate`.
        ``high`` must be a number here, not None.

    Returns
    -------
    numpy.ndarray
        Filtered samples with the same length as `samples`.
    """
    samples = np.asarray(samples)
    n_samples = samples.shape[-1]

    # Use the real FFT so the spectrum can be inverted with irfft without
    # losing the Nyquist bin or changing the output length.
    spectrum = np.fft.rfft(samples)
    freqs = np.fft.rfftfreq(n_samples, 1/samplerate)

    spectrum[..., ~bandpass_mask(freqs, band=band)] = 0

    return np.fft.irfft(spectrum, n=n_samples)
|
|
|
|
|
|
|
|
def bandpass_mask(freqs, band=passband()):
    """
    Boolean mask selecting the frequencies that lie inside `band`.

    The comparison is made on ``|freqs|``, so negative frequencies are kept
    when their magnitude is inside the band.  Both limits are inclusive.

    Parameters
    ----------
    freqs : array_like
        Frequencies to test (NumPy array or any sequence of numbers).
    band : passband
        ``(low, high)`` limits; ``high`` must be a number, not None.

    Returns
    -------
    numpy.ndarray of bool
        True where ``band[0] <= |freqs| <= band[1]``.
    """
    # np.abs (unlike builtin abs) also accepts plain lists/tuples.
    magnitude = np.abs(freqs)

    low_pass = magnitude <= band[1]
    high_pass = magnitude >= band[0]

    return low_pass & high_pass
|
|
|
|
|
2023-02-02 08:49:05 +01:00
|
|
|
def _plot_bandpower(ax, freqs, spectrum, power, bandmask, band):
    """Draw the spectrum and the band power computed by `bandpower` on `ax`."""
    if np.any(bandmask):
        min_f, max_f = np.min(freqs[bandmask]), np.max(freqs[bandmask])
    else:
        min_f, max_f = 0, 0

    if band[1] is None:
        # Single-frequency mode: mark exactly the requested frequency.
        min_f, max_f = band[0], band[0]

    lines = ax.plot(freqs, np.abs(spectrum), alpha=0.9)
    color = lines[0].get_color()

    amp = np.sqrt(power)

    if min_f != max_f:
        ax.plot([min_f, max_f], [amp, amp], alpha=0.7, color=color, ls='dashed')
        ax.axvspan(min_f, max_f, color=color, alpha=0.2)
    else:
        ax.plot(min_f, amp, '4', alpha=0.7, color=color, ms=10)


def bandpower(samples, samplerate=1, band=passband(), normalise_bandsize=True, debug_ax=False):
    """
    Compute the power of `samples` inside a frequency band.

    Parameters
    ----------
    samples : array_like
        Uniformly sampled time series.
    samplerate : float
        Samples per unit time; the sample spacing is ``1/samplerate``.
    band : passband
        ``(low, high)`` frequency limits.  If ``band[1]`` is None, only the
        single frequency ``band[0]`` is evaluated, using `dtft` instead of
        the FFT.
    normalise_bandsize : bool
        FFT path only: if True, divide the summed power by the number of
        FFT bins inside the band (at least 1).
    debug_ax : bool or matplotlib Axes
        If truthy, plot the spectrum and the resulting band power on this
        Axes; ``True`` means the current Axes (``plt.gca()``).

    Returns
    -------
    power
        FFT path: ``sum(|fft[band]|**2) / bins``.
        Single frequency: ``sum(|real**2 + imag**2|)`` of the two values
        returned by `dtft` (presumably the DTFT components at ``band[0]``
        -- confirm against lib.direct_fourier_transform).
        No 1/N normalisation is applied in either case.
    """
    spectrum, freqs = get_freq_spec(samples, 1/samplerate)

    if band[1] is None:
        # Only a single frequency given: use a DTFT for finding the power.
        # One timestamp per sample, spaced 1/samplerate apart (the previous
        # np.arange(0, N, 1/samplerate) produced N*samplerate points).
        time = np.arange(len(samples)) / samplerate

        real, imag = dtft(band[0], time, samples)
        power = np.sum(np.abs(real**2 + imag**2))

        # No FFT bins are selected in this mode.
        bandmask = np.zeros(len(freqs), dtype=bool)
    else:
        bandmask = bandpass_mask(freqs, band=band)

        if normalise_bandsize:
            bins = np.count_nonzero(bandmask, axis=-1)
        else:
            bins = 1

        # Avoid dividing by zero when no bin falls inside the band.
        bins = max(1, bins)

        power = 1/bins * np.sum(np.abs(spectrum[bandmask])**2)

    if debug_ax:
        if debug_ax is True:
            debug_ax = plt.gca()
        _plot_bandpower(debug_ax, freqs, spectrum, power, bandmask, band)

    return power
|
|
|
|
|
2023-04-13 17:09:30 +02:00
|
|
|
def signal_to_noise(samples, noise, samplerate=1, signal_band=passband(), noise_band=None, debug_ax=False, mode='sine'):
    """
    Ratio of the signal amplitude to the noise amplitude.

    Parameters
    ----------
    samples : array_like
        Time series containing the signal.
    noise : array_like or None
        Time series used to estimate the noise.  If None, `samples` is used.
    samplerate : float
        Samples per unit time, forwarded to `bandpower` in 'sine' mode.
    signal_band, noise_band : passband
        Bands passed to `bandpower` in 'sine' mode.  `noise_band` defaults
        to `signal_band`.
    debug_ax : bool or matplotlib Axes
        If truthy, draw diagnostics on this Axes; ``True`` means
        ``plt.gca()``.
    mode : {'sine', 'pulse'}
        'sine'  -- each amplitude is the square root of the `bandpower` of
                   its series in its band.
        'pulse' -- the signal amplitude is the peak ``|samples|`` and the
                   noise amplitude is the RMS of `noise`.

    Returns
    -------
    signal_amplitude / noise_amplitude

    Raises
    ------
    NotImplementedError
        If `mode` is neither 'sine' nor 'pulse'.
    """
    if noise_band is None:
        noise_band = signal_band

    if noise is None:
        noise = samples

    if debug_ax is True:
        debug_ax = plt.gca()

    if mode == 'sine':
        noise_power = bandpower(noise, samplerate, noise_band, debug_ax=debug_ax)
        noise_amplitude = np.sqrt(noise_power)

        signal_power = bandpower(samples, samplerate, signal_band, debug_ax=debug_ax)
        signal_amplitude = np.sqrt(signal_power)

    elif mode == 'pulse':
        # Convert so that plain sequences work too (`noise**2` needs an array).
        noise = np.asarray(noise)
        samples = np.asarray(samples)

        noise_amplitude = np.sqrt(np.mean(noise**2))
        signal_amplitude = np.max(np.abs(samples))

        if debug_ax:
            l1 = debug_ax.plot(noise, alpha=0.5)
            debug_ax.axhline(noise_amplitude, alpha=0.9, color=l1[0].get_color())

            l2 = debug_ax.plot(samples, alpha=0.5)
            debug_ax.axhline(signal_amplitude, alpha=0.9, color=l2[0].get_color())

    else:
        raise NotImplementedError("mode not in ['sine', 'pulse']")

    return signal_amplitude/noise_amplitude
|