SNR figure: return power from bandlevel: sum(fft**2)

This commit is contained in:
Eric Teunis de Boone 2022-10-27 21:53:26 +02:00
parent 820f58d901
commit d23f8adff2

View file

@ -2,7 +2,7 @@
__doc__ = \ __doc__ = \
""" """
Show Show the curve for signal-to-noise ratio vs N_samples
""" """
from collections import namedtuple from collections import namedtuple
@ -18,7 +18,7 @@ passband = namedtuple("Band", ['low', 'high'], defaults=[0, np.inf])
def get_freq_spec(val,dt): def get_freq_spec(val,dt):
"""From earsim/tools.py""" """From earsim/tools.py"""
fval = np.abs(np.fft.fft(val))[:len(val)//2] fval = np.fft.fft(val)[:len(val)//2]
freq = np.fft.fftfreq(len(val),dt)[:len(val)//2] freq = np.fft.fftfreq(len(val),dt)[:len(val)//2]
return fval, freq return fval, freq
@ -163,7 +163,8 @@ def save_all_figs_to_path(fnames, figs=None, default_basename=__file__, default_
if len(fnames) == len(figs): if len(fnames) == len(figs):
fnames_list = zip(figs, fnames, False) fnames_list = zip(figs, fnames, False)
elif len(fnames) == 1: elif len(fnames) == 1:
fnames_list = ( (fig, fnames[0], len(figs) > 1) for fig in figs) tmp_fname = fnames[0] #needed for generator
fnames_list = ( (fig, tmp_fname, len(figs) > 1) for fig in figs)
else: else:
# outer product magic # outer product magic
fnames_list = ( (fig,fname, False) for fname in fnames for fig in figs ) fnames_list = ( (fig,fname, False) for fname in fnames for fig in figs )
@ -224,7 +225,7 @@ def bandlevel(samples, samplerate=1, band=passband(), normalise_bandsize=True, *
else: else:
bins = 1 bins = 1
level = np.sum(np.abs(fft[bandmask])) level = np.sum(np.abs(fft[bandmask])**2)
return level/bins return level/bins
@ -277,7 +278,7 @@ def main(
signal_level = bandlevel(samples, f_sample, signal_band) signal_level = bandlevel(samples, f_sample, signal_band)
snrs[j] = signal_level/noise_level snrs[j] = np.sqrt(signal_level/noise_level)
# make a nice plot showing what ranges were taken # make a nice plot showing what ranges were taken
# and the bandlevels associated with them # and the bandlevels associated with them
@ -346,12 +347,12 @@ if __name__ == "__main__":
### ###
t_lengths = np.linspace(1e3, 5e4)* 1e-9 # s t_lengths = np.linspace(1e3, 5e4)* 1e-9 # s
N = 10e1 N = 10e1
f_sine = 53e6 # Hz f_sine = 53.3e6 # Hz
f_sample = 250e6 # Hz f_sample = 250e6 # Hz
if True: if False:
N = 2 # Note: keep this low, N figures will be displayed! N = 1 # Note: keep this low, N figures will be displayed!
N_t_length = 2 N_t_length = 10
for t_length in t_lengths[-N_t_length-1:-1]: for t_length in t_lengths[-N_t_length-1:-1]:
snrs = np.zeros( int(N)) snrs = np.zeros( int(N))
for i in range(int(N)): for i in range(int(N)):
@ -380,7 +381,7 @@ if __name__ == "__main__":
print(snrs, "M:",np.mean(snrs)) print(snrs, "M:",np.mean(snrs))
plt.show(block=True) plt.show(block=False)
else: else:
#original code #original code
@ -416,7 +417,8 @@ if __name__ == "__main__":
for j, t_length in enumerate(t_lengths): for j, t_length in enumerate(t_lengths):
t_length = t_length * f_sample t_length = t_length * f_sample
axs2.plot(np.repeat(t_length, my_snrs.shape[1]), my_snrs[j], ls='none', color='blue', marker='o', alpha=max(0.01, 1/my_snrs.shape[1])) axs2.plot(np.repeat(t_length, my_snrs.shape[1]), my_snrs[j], ls='none', color='blue', marker='o', alpha=max(0.01, 1/my_snrs.shape[1]))
axs2.plot(t_length, np.mean(my_snrs[j]), color='green', marker='*', ls='none') # plot the means
axs2.plot(t_lengths*f_sample, np.mean(my_snrs, axis=-1), color='green', marker='*', ls='none')
### Save or show figures ### Save or show figures
if not args.fname: if not args.fname: