import numpy as np
from collections import namedtuple

from lib import direct_fourier_transform as dtft
import matplotlib.pyplot as plt  # for debug plotting

# Frequency band (low, high), both edges inclusive. `high` defaults to +inf,
# i.e. no upper edge. In bandpower(), a band with high=None selects the
# single frequency `low`.
passband = namedtuple("passband", ['low', 'high'], defaults=[0, np.inf])


def get_freq_spec(val, dt):
    """Return the first half of the FFT of ``val`` and its frequencies.

    From earsim/tools.py.

    Parameters
    ----------
    val : array_like
        Time-domain samples.
    dt : float
        Sample spacing, i.e. ``1 / samplerate``.

    Returns
    -------
    fval : ndarray
        The first ``len(val) // 2`` FFT coefficients (Nyquist bin excluded).
    freq : ndarray
        The matching non-negative frequencies, in cycles per unit of ``dt``.
    """
    fval = np.fft.fft(val)[:len(val)//2]
    freq = np.fft.fftfreq(len(val), dt)[:len(val)//2]
    return fval, freq


def bandpass_samples(samples, samplerate, band=passband()):
    """
    Bandpass the samples with this passband.

    This is a hard filter: every frequency bin outside ``band`` is set to
    zero before transforming back to the time domain.

    Parameters
    ----------
    samples : array_like
        Real-valued time-domain samples (filtered along the last axis).
    samplerate : float
        Sampling rate; the band edges are in the same frequency units.
    band : passband or 2-tuple
        ``(low, high)`` edges, inclusive. ``high`` must not be None here.

    Returns
    -------
    ndarray
        The filtered signal, with the same length as ``samples``.
    """
    samples = np.asarray(samples)
    n = samples.shape[-1]
    # Use the real FFT rather than get_freq_spec(): it keeps the Nyquist bin,
    # so irfft(n=n) is an exact inverse and preserves the signal length.
    fft = np.fft.rfft(samples)
    freqs = np.fft.rfftfreq(n, 1/samplerate)
    fft[..., ~bandpass_mask(freqs, band=band)] = 0
    return np.fft.irfft(fft, n=n)


def bandpass_mask(freqs, band=passband()):
    """Return a boolean mask selecting frequencies whose magnitude is in ``band``.

    Both edges are inclusive: ``band[0] <= |f| <= band[1]``.
    """
    abs_freqs = np.abs(freqs)
    low_pass = abs_freqs <= band[1]
    high_pass = abs_freqs >= band[0]
    return low_pass & high_pass


def _plot_bandpower(ax, freqs, fft, bandmask, band, power):
    """Debug aid: plot the spectrum and the evaluated band/power on ``ax``."""
    if ax is True:
        ax = plt.gca()

    if band[1] is None:
        # Single-frequency evaluation: mark that frequency only.
        min_f = max_f = band[0]
    elif np.any(bandmask):
        min_f, max_f = min(freqs[bandmask]), max(freqs[bandmask])
    else:
        min_f = max_f = 0

    lines = ax.plot(freqs, np.abs(fft), alpha=0.9)
    color = lines[0].get_color()
    amp = np.sqrt(power)
    if min_f != max_f:
        ax.plot(
            [min_f, max_f], [amp, amp],
            alpha=0.7, color=color, ls='dashed')
        ax.axvspan(min_f, max_f, color=color, alpha=0.2)
    else:
        ax.plot(min_f, amp, '4', alpha=0.7, color=color, ms=10)


def bandpower(samples, samplerate=1, band=passband(),
              normalise_bandsize=True, debug_ax=False):
    """Compute the spectral power of ``samples`` within ``band``.

    If ``band[1]`` is None, the power at the single frequency ``band[0]``
    is computed with ``dtft`` (from ``lib``) over the sample timestamps.
    Otherwise the squared FFT magnitudes of the bins inside ``band`` are
    summed. If ``normalise_bandsize`` is true, the sum is divided by the
    number of bins in the band (at least 1).

    Parameters
    ----------
    samples : array_like
        Time-domain samples.
    samplerate : float
        Sampling rate; band edges use the same frequency units.
    band : passband or 2-tuple
        ``(low, high)`` edges, or ``(freq, None)`` for a single frequency.
    normalise_bandsize : bool
        Divide the FFT-band power by the number of bins (FFT path only).
    debug_ax : matplotlib Axes or bool
        If an Axes, plot the spectrum and band onto it; if True, use
        ``plt.gca()``; if falsy, do not plot.

    Returns
    -------
    float
        The (possibly bin-normalised) band power.
    """
    fft, freqs = get_freq_spec(samples, 1/samplerate)
    bandmask = np.zeros(len(freqs), dtype=bool)

    if band[1] is None:
        # Only a single frequency given: use a DTFT for finding the power.
        # One timestamp per sample, t_k = k / samplerate.
        time = np.arange(len(samples)) / samplerate
        # NOTE(review): assumes dtft(freq, time, samples) returns the real and
        # imaginary parts of the transform at `freq` — confirm against lib.
        real, imag = dtft(band[0], time, samples)
        power = np.sum(np.abs(real**2 + imag**2))
    else:
        bandmask = bandpass_mask(freqs, band=band)
        if normalise_bandsize:
            bins = np.count_nonzero(bandmask, axis=-1)
        else:
            bins = 1
        # Guard against an empty band (no bins) dividing by zero.
        bins = max(1, bins)
        power = 1/bins * np.sum(np.abs(fft[bandmask])**2)

    if debug_ax:
        _plot_bandpower(debug_ax, freqs, fft, bandmask, band, power)

    return power


def signal_to_noise(samples, noise, samplerate=1, signal_band=passband(),
                    noise_band=None, debug_ax=False, mode='sine'):
    """Return the ratio of signal amplitude to noise amplitude.

    Parameters
    ----------
    samples : array_like
        Real-valued time-domain samples containing the signal.
    noise : array_like or None
        Noise-only samples; if None, ``samples`` is used.
    samplerate : float
        Sampling rate, forwarded to ``bandpower`` in 'sine' mode.
    signal_band, noise_band : passband or 2-tuple
        Bands for the 'sine' mode; ``noise_band`` defaults to ``signal_band``.
    debug_ax : matplotlib Axes or bool
        Axes to draw debug plots on; True means ``plt.gca()``.
    mode : {'sine', 'pulse'}
        'sine': amplitudes are sqrt of the band powers.
        'pulse': signal amplitude is the peak ``|samples|`` and noise
        amplitude is the RMS of ``noise``.

    Raises
    ------
    NotImplementedError
        If ``mode`` is not 'sine' or 'pulse'.
    """
    if noise_band is None:
        noise_band = signal_band
    if noise is None:
        noise = samples
    if debug_ax is True:
        debug_ax = plt.gca()

    if mode == 'sine':
        noise_power = bandpower(noise, samplerate, noise_band, debug_ax=debug_ax)
        noise_amplitude = np.sqrt(noise_power)
        signal_power = bandpower(samples, samplerate, signal_band, debug_ax=debug_ax)
        signal_amplitude = np.sqrt(signal_power)
    elif mode == 'pulse':
        # Convert to float first: squaring integer (e.g. int16) samples
        # would overflow, and plain lists do not support ** at all.
        noise = np.asarray(noise, dtype=float)
        samples = np.asarray(samples, dtype=float)
        noise_amplitude = np.sqrt(np.mean(noise**2))
        signal_amplitude = max(np.abs(samples))

        if debug_ax:
            l1 = debug_ax.plot(noise, alpha=0.5)
            debug_ax.axhline(noise_amplitude, alpha=0.9, color=l1[0].get_color())
            l2 = debug_ax.plot(samples, alpha=0.5)
            debug_ax.axhline(signal_amplitude, alpha=0.9, color=l2[0].get_color())
    else:
        raise NotImplementedError("mode not in ['sine', 'pulse']")

    return signal_amplitude/noise_amplitude