SNR figure: return power from bandlevel: sum(fft**2)

This commit is contained in:
Eric Teunis de Boone 2022-10-27 21:53:26 +02:00
parent 820f58d901
commit d23f8adff2

View file

@ -2,7 +2,7 @@
__doc__ = \ __doc__ = \
""" """
Show Show the curve for signal-to-noise ratio vs N_samples
""" """
from collections import namedtuple from collections import namedtuple
@ -18,7 +18,7 @@ passband = namedtuple("Band", ['low', 'high'], defaults=[0, np.inf])
def get_freq_spec(val,dt): def get_freq_spec(val,dt):
"""From earsim/tools.py""" """From earsim/tools.py"""
fval = np.abs(np.fft.fft(val))[:len(val)//2] fval = np.fft.fft(val)[:len(val)//2]
freq = np.fft.fftfreq(len(val),dt)[:len(val)//2] freq = np.fft.fftfreq(len(val),dt)[:len(val)//2]
return fval, freq return fval, freq
@ -30,11 +30,11 @@ def ft_spectrum( signal, sample_rate=1, ftfunc=None, freqfunc=None, mask_bias=Fa
return get_freq_spec(signal, 1/sample_rate) return get_freq_spec(signal, 1/sample_rate)
n_samples = len(signal) n_samples = len(signal)
if ftfunc is None: if ftfunc is None:
real_signal = np.isrealobj(signal) real_signal = np.isrealobj(signal)
if False and real_signal: if False and real_signal:
ftfunc = ft.rfft ftfunc = ft.rfft
freqfunc = ft.rfftfreq freqfunc = ft.rfftfreq
else: else:
ftfunc = ft.fft ftfunc = ft.fft
@ -44,16 +44,16 @@ def ft_spectrum( signal, sample_rate=1, ftfunc=None, freqfunc=None, mask_bias=Fa
freqfunc = ft.fftfreq freqfunc = ft.fftfreq
normalisation = 2/len(signal) if normalise_amplitude else 1 normalisation = 2/len(signal) if normalise_amplitude else 1
spectrum = normalisation * ftfunc(signal) spectrum = normalisation * ftfunc(signal)
freqs = freqfunc(n_samples, 1/sample_rate) freqs = freqfunc(n_samples, 1/sample_rate)
if not mask_bias: if not mask_bias:
return spectrum, freqs return spectrum, freqs
else: else:
return spectrum[1:], freqs[1:] return spectrum[1:], freqs[1:]
def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_amplitude=None, ax=None, freq_unit="Hz", freq_scaler=1): def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_amplitude=None, ax=None, freq_unit="Hz", freq_scaler=1):
""" Plot a signal's spectrum on an Axis object""" """ Plot a signal's spectrum on an Axis object"""
plot_amplitude = plot_amplitude or (not plot_power and not plot_complex) plot_amplitude = plot_amplitude or (not plot_power and not plot_complex)
@ -61,7 +61,7 @@ def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_a
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
ax.set_title("Spectrum") ax.set_title("Spectrum")
ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" ))
ylabel = "" ylabel = ""
@ -80,7 +80,7 @@ def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_a
if plot_power: if plot_power:
ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha) ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha)
if plot_amplitude: if plot_amplitude:
ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha) ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha)
@ -97,7 +97,7 @@ def plot_phase( spectrum, freqs, ylim_epsilon=0.5, ax=None, freq_unit="Hz", freq
ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-') ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-')
ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon) ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon)
return ax return ax
def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **kwargs): def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **kwargs):
@ -112,13 +112,13 @@ def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **k
ax.set_ylabel("A(t)") ax.set_ylabel("A(t)")
ax.plot(time, signal, **kwargs) ax.plot(time, signal, **kwargs)
return ax return ax
def plot_combined_spectrum(spectrum, freqs, def plot_combined_spectrum(spectrum, freqs,
spectrum_kwargs={}, fig=None, gs=None, freq_scaler=1, freq_unit="Hz"): spectrum_kwargs={}, fig=None, gs=None, freq_scaler=1, freq_unit="Hz"):
"""Plot both the frequencies and phase in one figure.""" """Plot both the frequencies and phase in one figure."""
# configure plotting layout # configure plotting layout
if fig is None: if fig is None:
fig = plt.figure(figsize=(8, 16)) fig = plt.figure(figsize=(8, 16))
@ -130,8 +130,8 @@ def plot_combined_spectrum(spectrum, freqs,
ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1) ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1)
axes = np.array([ax1, ax2]) axes = np.array([ax1, ax2])
# plot the spectrum # plot the spectrum
plot_spectrum(spectrum, freqs, ax=ax1, freq_scaler=freq_scaler, freq_unit=freq_unit, **spectrum_kwargs) plot_spectrum(spectrum, freqs, ax=ax1, freq_scaler=freq_scaler, freq_unit=freq_unit, **spectrum_kwargs)
# plot the phase # plot the phase
@ -139,13 +139,13 @@ def plot_combined_spectrum(spectrum, freqs,
ax1.xaxis.tick_top() ax1.xaxis.tick_top()
[label.set_visible(False) for label in ax1.get_xticklabels()] [label.set_visible(False) for label in ax1.get_xticklabels()]
return fig, axes return fig, axes
def phasemod(phase, low=np.pi): def phasemod(phase, low=np.pi):
""" """
Modulo phase such that it falls within the Modulo phase such that it falls within the
interval $[-low, 2\pi - low)$. interval $[-low, 2\pi - low)$.
""" """
return (phase + low) % (2*np.pi) - low return (phase + low) % (2*np.pi) - low
@ -153,7 +153,7 @@ def phasemod(phase, low=np.pi):
def save_all_figs_to_path(fnames, figs=None, default_basename=__file__, default_extensions=['.pdf', '.png']): def save_all_figs_to_path(fnames, figs=None, default_basename=__file__, default_extensions=['.pdf', '.png']):
if figs is None: if figs is None:
figs = [plt.figure(i) for i in plt.get_fignums()] figs = [plt.figure(i) for i in plt.get_fignums()]
default_basename = path.basename(default_basename) default_basename = path.basename(default_basename)
# singular value # singular value
@ -163,7 +163,8 @@ def save_all_figs_to_path(fnames, figs=None, default_basename=__file__, default_
if len(fnames) == len(figs): if len(fnames) == len(figs):
fnames_list = zip(figs, fnames, False) fnames_list = zip(figs, fnames, False)
elif len(fnames) == 1: elif len(fnames) == 1:
fnames_list = ( (fig, fnames[0], len(figs) > 1) for fig in figs) tmp_fname = fnames[0] #needed for generator
fnames_list = ( (fig, tmp_fname, len(figs) > 1) for fig in figs)
else: else:
# outer product magic # outer product magic
fnames_list = ( (fig,fname, False) for fname in fnames for fig in figs ) fnames_list = ( (fig,fname, False) for fname in fnames for fig in figs )
@ -224,7 +225,7 @@ def bandlevel(samples, samplerate=1, band=passband(), normalise_bandsize=True, *
else: else:
bins = 1 bins = 1
level = np.sum(np.abs(fft[bandmask])) level = np.sum(np.abs(fft[bandmask])**2)
return level/bins return level/bins
@ -265,7 +266,7 @@ def main(
for j in range(N): for j in range(N):
samples, noise = noisy_sine_sampling(time, init_params, noise_sigma) samples, noise = noisy_sine_sampling(time, init_params, noise_sigma)
# determine signal to noise # determine signal to noise
noise_level = bandlevel(noise, f_sample, noise_band) noise_level = bandlevel(noise, f_sample, noise_band)
if cut_signal_band_from_noise_band: if cut_signal_band_from_noise_band:
@ -277,7 +278,7 @@ def main(
signal_level = bandlevel(samples, f_sample, signal_band) signal_level = bandlevel(samples, f_sample, signal_band)
snrs[j] = signal_level/noise_level snrs[j] = np.sqrt(signal_level/noise_level)
# make a nice plot showing what ranges were taken # make a nice plot showing what ranges were taken
# and the bandlevels associated with them # and the bandlevels associated with them
@ -293,23 +294,23 @@ def main(
if True: if True:
freq_scaler=1e6 freq_scaler=1e6
_, axs = plot_combined_spectrum(combined_fft, freqs, freq_scaler=freq_scaler, freq_unit='MHz') _, axs = plot_combined_spectrum(combined_fft, freqs, freq_scaler=freq_scaler, freq_unit='MHz')
# indicate band ranges and frequency # indicate band ranges and frequency
for ax in axs: for ax in axs:
ax.axvline(f_sine/freq_scaler, color='r', alpha=0.4) ax.axvline(f_sine/freq_scaler, color='r', alpha=0.4)
ax.axvspan(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, color='purple', alpha=0.3, label='noiseband') ax.axvspan(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, color='purple', alpha=0.3, label='noiseband')
ax.axvspan(signal_band[0]/freq_scaler, signal_band[1]/freq_scaler, color='orange', alpha=0.3, label='signalband') ax.axvspan(signal_band[0]/freq_scaler, signal_band[1]/freq_scaler, color='orange', alpha=0.3, label='signalband')
# indicate initial phase # indicate initial phase
axs[1].axhline(init_params[2], color='r', alpha=0.4) axs[1].axhline(init_params[2], color='r', alpha=0.4)
# plot the band levels # plot the band levels
levelax = axs[0].twinx() levelax = axs[0].twinx()
levelax.set_ylabel("Bandlevel") levelax.set_ylabel("Bandlevel")
levelax.hlines(signal_level, noise_band[0]/freq_scaler, signal_band[1]/freq_scaler, colors=['orange']) levelax.hlines(signal_level, noise_band[0]/freq_scaler, signal_band[1]/freq_scaler, colors=['orange'])
levelax.hlines(noise_level, noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, colors=['purple']) levelax.hlines(noise_level, noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, colors=['purple'])
levelax.set_ylim(bottom=0) levelax.set_ylim(bottom=0)
axs[0].legend() axs[0].legend()
# plot signal_band pass signal # plot signal_band pass signal
@ -319,7 +320,7 @@ def main(
fft = np.fft.fft(samples) fft = np.fft.fft(samples)
fft[ ~bandmask ] = 0 fft[ ~bandmask ] = 0
bandpassed_samples = np.fft.ifft(fft) bandpassed_samples = np.fft.ifft(fft)
_, ax3 = plt.subplots() _, ax3 = plt.subplots()
ax3 = plot_signal(bandpassed_samples, sample_rate=f_sample/1e6, time_unit='us', ax=ax3) ax3 = plot_signal(bandpassed_samples, sample_rate=f_sample/1e6, time_unit='us', ax=ax3)
ax3.set_title("Bandpassed Signal") ax3.set_title("Bandpassed Signal")
@ -346,12 +347,12 @@ if __name__ == "__main__":
### ###
t_lengths = np.linspace(1e3, 5e4)* 1e-9 # s t_lengths = np.linspace(1e3, 5e4)* 1e-9 # s
N = 10e1 N = 10e1
f_sine = 53e6 # Hz f_sine = 53.3e6 # Hz
f_sample = 250e6 # Hz f_sample = 250e6 # Hz
if True: if False:
N = 2 # Note: keep this low, N figures will be displayed! N = 1 # Note: keep this low, N figures will be displayed!
N_t_length = 2 N_t_length = 10
for t_length in t_lengths[-N_t_length-1:-1]: for t_length in t_lengths[-N_t_length-1:-1]:
snrs = np.zeros( int(N)) snrs = np.zeros( int(N))
for i in range(int(N)): for i in range(int(N)):
@ -360,18 +361,18 @@ if __name__ == "__main__":
N=1, N=1,
t_length=t_length, t_length=t_length,
f_sample=f_sample, f_sample=f_sample,
# signal properties # signal properties
f_sine = f_sine, f_sine = f_sine,
sine_amp = 1, sine_amp = 1,
noise_sigma = 1, noise_sigma = 1,
noise_band = passband(30e6, 80e6), noise_band = passband(30e6, 80e6),
signal_band = passband(f_sine- 3*delta_f, f_sine + 3*delta_f), signal_band = passband(f_sine- 3*delta_f, f_sine + 3*delta_f),
return_ranges_plot=True return_ranges_plot=True
) )
axs[0].set_title("SNR: {}, N:{}".format(snrs[i], t_length*f_sample)) axs[0].set_title("SNR: {}, N:{}".format(snrs[i], t_length*f_sample))
axs[0].set_xlim( axs[0].set_xlim(
(f_sine - 20*delta_f)/1e6, (f_sine - 20*delta_f)/1e6,
@ -380,43 +381,44 @@ if __name__ == "__main__":
print(snrs, "M:",np.mean(snrs)) print(snrs, "M:",np.mean(snrs))
plt.show(block=True) plt.show(block=False)
else: else:
#original code #original code
my_snrs = np.zeros( (len(t_lengths), int(N)) ) my_snrs = np.zeros( (len(t_lengths), int(N)) )
for j, t_length in enumerate(t_lengths): for j, t_length in enumerate(t_lengths):
return_ranges_plot = ((j==0) and True) or ( (j==(len(t_lengths)-1)) and True) return_ranges_plot = ((j==0) and True) or ( (j==(len(t_lengths)-1)) and True)
delta_f = 1/t_length delta_f = 1/t_length
my_snrs[j], axs = main( my_snrs[j], axs = main(
N=N, N=N,
t_length=t_length, t_length=t_length,
f_sample = f_sample, f_sample = f_sample,
# signal properties # signal properties
f_sine = f_sine, f_sine = f_sine,
sine_amp = 1, sine_amp = 1,
noise_sigma = 1, noise_sigma = 1,
noise_band = passband(30e6, 80e6), noise_band = passband(30e6, 80e6),
signal_band = passband(f_sine- 3*delta_f, f_sine + 3*delta_f), signal_band = passband(f_sine- 3*delta_f, f_sine + 3*delta_f),
return_ranges_plot=return_ranges_plot, return_ranges_plot=return_ranges_plot,
) )
if return_ranges_plot: if return_ranges_plot:
ranges_axs = axs ranges_axs = axs
fig, axs2 = plt.subplots() fig, axs2 = plt.subplots()
axs2.set_xlabel("N = T*$f_s$") axs2.set_xlabel("N = T*$f_s$")
axs2.set_ylabel("SNR") axs2.set_ylabel("SNR")
for j, t_length in enumerate(t_lengths): for j, t_length in enumerate(t_lengths):
t_length = t_length * f_sample t_length = t_length * f_sample
axs2.plot(np.repeat(t_length, my_snrs.shape[1]), my_snrs[j], ls='none', color='blue', marker='o', alpha=max(0.01, 1/my_snrs.shape[1])) axs2.plot(np.repeat(t_length, my_snrs.shape[1]), my_snrs[j], ls='none', color='blue', marker='o', alpha=max(0.01, 1/my_snrs.shape[1]))
axs2.plot(t_length, np.mean(my_snrs[j]), color='green', marker='*', ls='none') # plot the means
axs2.plot(t_lengths*f_sample, np.mean(my_snrs, axis=-1), color='green', marker='*', ls='none')
### Save or show figures ### Save or show figures
if not args.fname: if not args.fname: