m-thesis-introduction/fourier/signal_to_noise.py

426 lines
13 KiB
Python

#!/usr/bin/env python3
__doc__ = \
"""
Estimate the signal-to-noise ratio (SNR) of a sine wave in Gaussian noise
from their frequency spectra.

The signal level is the mean spectral amplitude of the noise-free sine in a
narrow band around its frequency; the noise level is the mean spectral
amplitude of the noise in a wider band. Their ratio is the SNR. The figures
marking the bands used are shown, or saved when a path is given.
"""
from collections import namedtuple
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import scipy.fftpack as ft
# Module-level default random number generator.
rng = np.random.default_rng()
# A (low, high) frequency band; by default it spans every frequency.
passband = namedtuple("Band", ("low", "high"), defaults=(0, np.inf))
def get_freq_spec(val,dt):
    """Return the one-sided amplitude spectrum of `val` and its frequencies.

    Adapted from earsim/tools.py.

    Parameters
    ----------
    val : array_like
        Sampled signal.
    dt : float
        Sample spacing.

    Returns
    -------
    fval, freq
        ``|FFT(val)|`` and the matching frequencies, both truncated to the
        first ``len(val)//2`` bins (the non-negative frequencies from DC up).
    """
    n_half = len(val) // 2
    amplitude = np.abs(np.fft.fft(val))
    frequencies = np.fft.fftfreq(len(val), dt)
    return amplitude[:n_half], frequencies[:n_half]
def ft_spectrum( signal, sample_rate=1, ftfunc=None, freqfunc=None, mask_bias=False, normalise_amplitude=False):
    """Return a Fourier transform of `signal` with its corresponding frequencies.

    By default (``ftfunc is None``) this is the one-sided amplitude spectrum
    from ``get_freq_spec``: ``|FFT(signal)|`` for the first
    ``len(signal)//2`` bins, so the returned spectrum is real and carries no
    phase information.

    Parameters
    ----------
    signal : array_like
        Sampled signal.
    sample_rate : float
        Sampling rate; the sample spacing is ``1/sample_rate``.
    ftfunc : callable, optional
        Custom transform, called as ``ftfunc(signal)``. When given, its raw
        output is returned instead of the default amplitude spectrum.
    freqfunc : callable, optional
        Called as ``freqfunc(n_samples, 1/sample_rate)`` together with a
        custom `ftfunc`; defaults to ``scipy.fftpack.fftfreq``. Ignored when
        `ftfunc` is None.
    mask_bias : bool
        Drop the first bin (the DC term for FFT-ordered output) from both
        the spectrum and the frequencies.
    normalise_amplitude : bool
        Multiply the spectrum by ``2/len(signal)``.

    Returns
    -------
    spectrum, freqs
    """
    n_samples = len(signal)

    if ftfunc is None:
        spectrum, freqs = get_freq_spec(signal, 1/sample_rate)
    else:
        if freqfunc is None:
            freqfunc = ft.fftfreq
        spectrum = ftfunc(signal)
        freqs = freqfunc(n_samples, 1/sample_rate)

    if normalise_amplitude:
        spectrum = (2/n_samples) * spectrum

    if mask_bias:
        return spectrum[1:], freqs[1:]
    return spectrum, freqs
def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_amplitude=None, ax=None, freq_unit="Hz", freq_scaler=1):
    """Draw a spectrum on a matplotlib Axes and return that Axes.

    Curves, drawn in this order with marker style '.-':
      * plot_complex   -> real and imaginary parts (labels 'Real', 'Imag')
      * plot_power     -> |spectrum|**2 (label 'Power')
      * plot_amplitude -> |spectrum| (label 'Abs'); when not truthy it is
                          switched on only if neither plot_power nor
                          plot_complex is requested.
    When plot_complex is set, every curve is drawn with alpha 0.5.

    The x values are freqs / freq_scaler and `freq_unit` (if non-empty) is
    shown in the x label. `ax` defaults to the current pyplot Axes.
    """
    draw_abs = plot_amplitude or not (plot_power or plot_complex)
    opacity = 0.5 if plot_complex else 1

    target = plt.gca() if ax is None else ax
    target.set_title("Spectrum")
    target.set_xlabel("f" + (" [" + freq_unit + "]" if freq_unit else ""))

    label_parts = []
    if draw_abs or plot_complex:
        label_parts.append("Amplitude")
    if plot_power:
        label_parts.append("Power")
    target.set_ylabel("|".join(label_parts))

    x = freqs/freq_scaler
    curves = []
    if plot_complex:
        curves.append((np.real(spectrum), 'Real'))
        curves.append((np.imag(spectrum), 'Imag'))
    if plot_power:
        curves.append((np.abs(spectrum)**2, 'Power'))
    if draw_abs:
        curves.append((np.abs(spectrum), 'Abs'))
    for y, name in curves:
        target.plot(x, y, '.-', label=name, alpha=opacity)

    target.legend()
    return target
def plot_phase( spectrum, freqs, ylim_epsilon=0.5, ax=None, freq_unit="Hz", freq_scaler=1):
    """Plot the complex phase (``np.angle``) of `spectrum` against frequency.

    The x values are freqs / freq_scaler, labelled with `freq_unit` when it
    is non-empty. The y range is fixed to
    [-pi - ylim_epsilon, pi + ylim_epsilon]. Returns the Axes used (the
    current pyplot Axes when `ax` is None).
    """
    target = plt.gca() if ax is None else ax
    target.set_ylabel("Phase")
    target.set_xlabel("f" + (" [" + freq_unit + "]" if freq_unit else ""))
    target.plot(freqs/freq_scaler, np.angle(spectrum), '.-')
    half_range = np.pi + ylim_epsilon
    target.set_ylim(-half_range, half_range)
    return target
def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **kwargs):
    """Plot `signal` against time on a matplotlib Axes and return that Axes.

    `time` defaults to ``arange(len(signal)) / sample_rate``; `time_unit`
    (if non-empty) is shown in the x label. Extra keyword arguments are
    forwarded to ``ax.plot``. `ax` defaults to the current pyplot Axes.
    """
    target = ax if ax is not None else plt.gca()
    if time is None:
        time = np.arange(len(signal))/sample_rate

    target.set_title("Signal")
    target.set_xlabel("t" + (" [" + time_unit + "]" if time_unit else ""))
    target.set_ylabel("A(t)")
    target.plot(time, signal, **kwargs)
    return target
def plot_combined_spectrum(spectrum, freqs,
        spectrum_kwargs=None, fig=None, gs=None, freq_scaler=1, freq_unit="Hz"):
    """Plot the spectrum and its phase in one figure with a shared x axis.

    Parameters
    ----------
    spectrum, freqs
        Passed to ``plot_spectrum`` (top panel) and ``plot_phase`` (bottom).
    spectrum_kwargs : dict, optional
        Extra keyword arguments for ``plot_spectrum``.
    fig : Figure, optional
        Figure to draw in; a new 8x16 figure is created when omitted.
    gs : GridSpec, optional
        Layout to use; defaults to a 2x1 grid with height ratios 3:1 and no
        vertical spacing. The spectrum goes in ``gs[:-1, -1]`` and the phase
        in ``gs[-1, -1]``.
    freq_scaler, freq_unit
        Frequency scaling and unit label, forwarded to both panels.

    Returns
    -------
    fig, axes
        The figure and ``np.array([spectrum_ax, phase_ax])``.
    """
    if spectrum_kwargs is None:
        spectrum_kwargs = {}

    # configure plotting layout
    if fig is None:
        fig = plt.figure(figsize=(8, 16))
    if gs is None:
        gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3,1], hspace=0)
    ax1 = fig.add_subplot(gs[:-1, -1])
    ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1)
    axes = np.array([ax1, ax2])

    # plot the spectrum
    plot_spectrum(spectrum, freqs, ax=ax1, freq_scaler=freq_scaler, freq_unit=freq_unit, **spectrum_kwargs)
    # plot the phase
    plot_phase(spectrum, freqs, ax=ax2, freq_scaler=freq_scaler, freq_unit=freq_unit)

    # The x axis is shared with the phase panel below, so move the spectrum
    # panel's ticks to its top edge and hide its x tick labels.
    ax1.xaxis.tick_top()
    for label in ax1.get_xticklabels():
        label.set_visible(False)
    return fig, axes
def phasemod(phase, low=np.pi):
    r"""
    Wrap `phase` into the half-open interval $[-low, 2\pi - low)$.

    With the default ``low = pi`` the result lies in $[-\pi, \pi)$. Works
    element-wise on numpy arrays.
    """
    return (phase + low) % (2*np.pi) - low
def save_all_figs_to_path(fnames, figs=None, default_basename=__file__, default_extensions=('.pdf', '.png')):
    """Save matplotlib figures under one or more file names.

    Parameters
    ----------
    fnames : str or sequence
        Target name(s).

        * If there are exactly as many entries as figures, the i-th entry
          (a single name or a list of names) is used for the i-th figure.
        * Otherwise every name is used for every figure. When more than one
          figure is saved this way, ``_figNN`` (the zero-padded figure
          number) is appended so the files do not overwrite each other.

        A name that is an existing directory gets the stem of
        `default_basename` appended. A name without an extension is saved
        once for every entry of `default_extensions`.
    figs : list of Figure, optional
        Figures to save; defaults to all open pyplot figures. Nothing is
        saved when there are none.
    default_basename : str
        Path whose basename, without extension, is used inside directories.
    default_extensions : sequence of str
        Extensions appended to names that have none.
    """
    # Imported here: the module only imports os.path in its __main__ section.
    import os.path as path

    if figs is None:
        figs = [plt.figure(i) for i in plt.get_fignums()]
    if not figs:
        return

    stem = path.splitext(path.basename(default_basename))[0]

    # a single name is treated as a one-element list
    if isinstance(fnames, str):
        fnames = [fnames]

    if len(fnames) == len(figs):
        # the i-th entry belongs to the i-th figure
        jobs = [(fig, names, False) for fig, names in zip(figs, fnames)]
    else:
        # every name is used for every figure; number the files when there
        # are several figures so they do not overwrite each other
        numbered = len(figs) > 1
        jobs = [(fig, name, numbered) for name in fnames for fig in figs]

    pad_width = max(2, len(str(len(figs))))

    for fig, names, append_num in jobs:
        if isinstance(names, str) or not hasattr(names, '__len__'):
            names = [names]
        for fname in names:
            if path.isdir(fname):
                fname = path.join(fname, stem)
            if append_num:
                fname += f"_fig{fig.number:0{pad_width}d}"
            if path.splitext(fname)[1]:
                targets = [fname]
            else:
                targets = [fname + ext for ext in default_extensions]
            for target in targets:
                fig.savefig(target, transparent=True)
def sine_fitfunc(t, amp=1, freq=1, phase=0, off=0):
    """Sine model ``amp * sin(2*pi*freq*t + phase) + off`` for curve fitting."""
    angle = 2*np.pi*freq*t + phase
    return amp*np.sin(angle) + off
def sampled_time(sample_rate=1, start=0, end=1, offset=0):
    """Sample times from `start` (inclusive) to `end` (exclusive).

    The spacing is ``1/sample_rate`` and `offset` is added to every time.
    """
    step = 1/sample_rate
    grid = np.arange(start, end, step)
    return offset + grid
def bandpass_mask(freqs, band=passband()):
    """Boolean mask of the frequencies whose magnitude lies inside `band`.

    Both edges are inclusive, ``band[0] <= |f| <= band[1]``, so negative
    frequencies are selected like their positive counterparts.
    """
    magnitude = abs(freqs)
    below_high = magnitude <= band[1]
    above_low = magnitude >= band[0]
    return below_high & above_low
def bandsize(band = passband()):
    """Width of `band`: its high edge minus its low edge."""
    low, high = band[0], band[1]
    return high - low
def bandlevel(samples, samplerate=1, band=passband(), normalise_bandsize=True, **ft_kwargs):
    """Spectral amplitude of `samples` summed over the bins inside `band`.

    The spectrum is ``ft_spectrum(samples, samplerate, **ft_kwargs)`` and the
    bins are selected with ``bandpass_mask``. With `normalise_bandsize`
    (default) the sum is divided by the number of selected bins, i.e. the
    result is the mean amplitude per bin.
    NOTE(review): if no bin falls in the band this divides 0 by 0, which
    numpy presumably turns into nan with a RuntimeWarning.
    """
    spectrum, freqs = ft_spectrum(samples, samplerate, **ft_kwargs)
    in_band = bandpass_mask(freqs, band=band)
    n_bins = np.count_nonzero(in_band, axis=-1) if normalise_bandsize else 1
    return np.sum(np.abs(spectrum[in_band])) / n_bins
def noisy_sine_sampling(time, init_params, noise_sigma=1, rng=None):
    """Sample a sine wave at `time` and draw matching Gaussian noise.

    Parameters
    ----------
    time : array_like
        Sample times passed to ``sine_fitfunc``.
    init_params : mutable sequence
        ``(amp, freq, phase, off)`` for ``sine_fitfunc``. If the phase
        (index 2) is None, a uniformly random phase wrapped to [-pi, pi) is
        drawn and written back into `init_params` in place, so the caller
        can read the phase that was used.
    noise_sigma : float
        Standard deviation of the zero-mean normal noise.
    rng : numpy.random.Generator, optional
        Random source. Defaults to the module-level ``rng``, looked up at
        call time so that rebinding it (e.g. to a seeded generator) takes
        effect; a default bound at definition time would not see that.

    Returns
    -------
    samples, noise
        The noise-free sine samples and a noise array of the same length;
        they are returned separately, not summed.
    """
    if rng is None:
        # the parameter shadows the module-level name, hence globals()
        rng = globals()["rng"]
    if init_params[2] is None:
        init_params[2] = phasemod(2*np.pi*rng.random())
    samples = sine_fitfunc(time, *init_params)
    noise = rng.normal(0, noise_sigma, size=len(samples))
    return samples, noise
def _cut_band_noise_level(noise, f_sample, noise_band, signal_band):
    """Mean spectral amplitude per bin of `noise` in `noise_band` without `signal_band`.

    The bins below and above the signal band are pooled before averaging, so
    the result is on the same per-bin scale as
    ``bandlevel(noise, f_sample, noise_band)``; adding two separate per-bin
    means would roughly double it.
    """
    spectrum, freqs = ft_spectrum(noise, f_sample)
    lower = bandpass_mask(freqs, band=passband(noise_band[0], signal_band[0]))
    upper = bandpass_mask(freqs, band=passband(signal_band[1], noise_band[1]))
    in_band = lower | upper
    return np.sum(np.abs(spectrum[in_band])) / np.count_nonzero(in_band)


def _plot_snr_ranges(samples, noise, f_sample, f_sine, noise_band, signal_band,
                     signal_level, noise_level, phase,
                     plot_raw_signal=False, plot_bandpassed_signal=False):
    """Plot the spectrum of samples+noise with the bands, the sine frequency,
    the initial phase and both band levels marked; return the two Axes."""
    freq_scaler = 1e6
    combined_fft, freqs = ft_spectrum(samples+noise, f_sample)

    # optionally plot the original (noisy) signal
    if plot_raw_signal:
        _, ax = plt.subplots()
        plot_signal(samples+noise, sample_rate=f_sample/1e6, time_unit='us', ax=ax)

    # plot the spectrum
    _, axs = plot_combined_spectrum(combined_fft, freqs, freq_scaler=freq_scaler, freq_unit='MHz')
    # indicate band ranges and frequency
    for ax in axs:
        ax.axvline(f_sine/freq_scaler, color='r', alpha=0.4)
        ax.axvspan(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, color='purple', alpha=0.3, label='noiseband')
        ax.axvspan(signal_band[0]/freq_scaler, signal_band[1]/freq_scaler, color='orange', alpha=0.3, label='signalband')
    # indicate initial phase
    axs[1].axhline(phase, color='r', alpha=0.4)
    # plot the band levels on a secondary y axis
    levelax = axs[0].twinx()
    levelax.set_ylabel("Bandlevel")
    # NOTE(review): the signal level line starts at the noise band's low
    # edge rather than the signal band's -- confirm this is intended.
    levelax.hlines(signal_level, noise_band[0]/freq_scaler, signal_band[1]/freq_scaler, colors=['orange'])
    levelax.hlines(noise_level, noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, colors=['purple'])
    levelax.set_ylim(bottom=0)
    axs[0].legend()

    # optionally plot the signal band-pass filtered noise-free signal
    if plot_bandpassed_signal:
        full_freqs = np.fft.fftfreq(len(samples), 1/f_sample)
        bandmask = bandpass_mask(full_freqs, band=signal_band)
        fft = np.fft.fft(samples)
        fft[~bandmask] = 0
        # ifft returns a complex array; plot its real part explicitly
        bandpassed_samples = np.real(np.fft.ifft(fft))
        _, ax3 = plt.subplots()
        ax3 = plot_signal(bandpassed_samples, sample_rate=f_sample/1e6, time_unit='us', ax=ax3)
        ax3.set_title("Bandpassed Signal")

    return axs


def main(
        N = 1,
        f_sample = 250e6, # Hz
        t_length = 1e4 * 1e-9, # s
        noise_band = passband(30e6, 80e6),
        noise_sigma = 1,
        # signal properties
        f_sine = 50e6,
        signal_band = passband(50e6 - 1e6, 50e6 + 1e6),
        sine_amp = 0.2,
        sine_offset = 0,
        return_ranges_plot = False,
        cut_signal_band_from_noise_band = False,
        plot_raw_signal = False,
        plot_bandpassed_signal = False,
        ):
    """Estimate the SNR of a noisy sine wave for `N` independent noise draws.

    For every draw the signal level is ``bandlevel`` of the noise-free sine
    samples in `signal_band`, and the noise level is ``bandlevel`` of the
    noise in `noise_band` (or, with `cut_signal_band_from_noise_band`, the
    pooled per-bin mean of `noise_band` with `signal_band` removed). The SNR
    is their ratio.

    Parameters
    ----------
    N : int
        Number of noise realisations.
    f_sample : float
        Sampling rate [Hz].
    t_length : float
        Trace length [s].
    noise_band, signal_band : passband
        Frequency bands [Hz] used for the noise and signal levels.
    noise_sigma : float
        Standard deviation of the Gaussian noise.
    f_sine, sine_amp, sine_offset : float
        Sine frequency [Hz], amplitude and offset. The phase is random.
    return_ranges_plot : bool
        Plot the spectrum of the first draw with the bands and levels marked.
    cut_signal_band_from_noise_band : bool
        Exclude `signal_band` from the noise band.
    plot_raw_signal, plot_bandpassed_signal : bool
        With `return_ranges_plot`, also plot the noisy time trace and/or the
        signal band-pass filtered sine.

    Returns
    -------
    snrs : ndarray of shape (N,)
    axs : ndarray of two Axes when `return_ranges_plot` is set, else None
    """
    N = int(N)
    init_params = np.array([sine_amp, f_sine, None, sine_offset])
    axs = None
    snrs = np.zeros( N )
    time = sampled_time(f_sample, end=t_length)

    for j in range(N):
        # The phase (init_params[2]) is drawn on the first call and written
        # back, so all N realisations share it; only the noise differs.
        samples, noise = noisy_sine_sampling(time, init_params, noise_sigma)

        # determine signal to noise
        if cut_signal_band_from_noise_band:
            noise_level = _cut_band_noise_level(noise, f_sample, noise_band, signal_band)
        else:
            noise_level = bandlevel(noise, f_sample, noise_band)
        signal_level = bandlevel(samples, f_sample, signal_band)
        snrs[j] = signal_level/noise_level

        # make a plot showing what ranges were taken
        # and the bandlevels associated with them
        if return_ranges_plot and j == 0:
            axs = _plot_snr_ranges(
                samples, noise, f_sample, f_sine, noise_band, signal_band,
                signal_level, noise_level, init_params[2],
                plot_raw_signal=plot_raw_signal,
                plot_bandpassed_signal=plot_bandpassed_signal,
            )

    return snrs, axs
if __name__ == "__main__":
    from argparse import ArgumentParser
    import os.path as path

    # Seed the module-level generator for reproducible figures.
    rng = np.random.default_rng(1)

    parser = ArgumentParser(description=__doc__)
    parser.add_argument("fname", metavar="path/to/figure[/]", nargs="?", help="Location for generated figure, will append __file__ if a directory. If not supplied, figure is shown.")
    parser.add_argument("--sweep", action="store_true", help="Sweep the trace length and plot the SNR against N = T*f_s instead of showing individual spectra.")
    args = parser.parse_args()

    if args.fname == 'none':
        args.fname = None

    ###
    t_lengths = np.linspace(1e3, 5e4) * 1e-9 # s
    f_sine = 53e6 # Hz
    f_sample = 250e6 # Hz

    def _run_main(t_length, N, return_ranges_plot):
        """Run main() for one trace length with a signal band of f_sine +- 3/t_length."""
        delta_f = 1/t_length  # ~ frequency resolution of a trace this long
        return main(
            N=N,
            t_length=t_length,
            f_sample=f_sample,
            # signal properties
            f_sine=f_sine,
            sine_amp=1,
            noise_sigma=1,
            noise_band=passband(30e6, 80e6),
            signal_band=passband(f_sine - 3*delta_f, f_sine + 3*delta_f),
            return_ranges_plot=return_ranges_plot,
        )

    if not args.sweep:
        N = 2 # Note: keep this low, N figures will be displayed per trace length!
        N_t_length = 2
        # NOTE(review): this selects the two lengths *before* the longest one;
        # use t_lengths[-N_t_length:] if the longest trace was intended.
        for t_length in t_lengths[-N_t_length-1:-1]:
            delta_f = 1/t_length
            snrs = np.zeros(N)
            for i in range(N):
                run_snrs, axs = _run_main(t_length, 1, True)
                snrs[i] = run_snrs[0]  # main() returns one SNR per realisation
                axs[0].set_title("SNR: {}, N:{}".format(snrs[i], t_length*f_sample))
                axs[0].set_xlim(
                    (f_sine - 20*delta_f)/1e6,
                    (f_sine + 20*delta_f)/1e6
                )
            print(snrs, "M:",np.mean(snrs))
    else:
        N = 100
        my_snrs = np.zeros((len(t_lengths), N))
        last = len(t_lengths) - 1
        for j, t_length in enumerate(t_lengths):
            # only draw the ranges plot for the shortest and the longest trace
            my_snrs[j], _ = _run_main(t_length, N, j in (0, last))

        _, axs2 = plt.subplots()
        axs2.set_xlabel("N = T*$f_s$")
        axs2.set_ylabel("SNR")
        alpha = max(0.01, 1/my_snrs.shape[1])
        for t_length, snrs in zip(t_lengths, my_snrs):
            n_samples = t_length * f_sample
            axs2.plot(np.repeat(n_samples, len(snrs)), snrs, ls='none', color='blue', marker='o', alpha=alpha)
            axs2.plot(n_samples, np.mean(snrs), color='green', marker='*', ls='none')

    ### Save or show figures
    # (Figures must not be shown before this point: closing the windows
    # destroys them, leaving nothing to save.)
    if not args.fname:
        # empty list, False, None
        plt.show()
    else:
        save_all_figs_to_path(args.fname, default_basename=__file__)