#!/usr/bin/env python3
# vim: fdm=indent ts=4

# The module docstring doubles as the command-line help text
# (it is passed as the ArgumentParser description in __main__).
__doc__ = \
"""
Show the curve for signal-to-noise ratio vs N_samples
"""

import matplotlib.pyplot as plt
import numpy as np

from mylib import *

# Default (unseeded) random generator. It is bound as the default `rng`
# argument of noisy_sine_realisation_snr when that function is defined;
# rebinding this name later (as the __main__ block does) does not change
# that default, so callers wanting a specific generator must pass it.
rng = np.random.default_rng()

def noisy_sine_realisation_snr(
        N = 1,
        f_sample = 250e6, # Hz
        t_length = 1e4 * 1e-9, # s

        noise_band = passband(30e6, 80e6),
        noise_sigma = 1,

    # signal properties
        f_sine = 50e6,
        signal_band = passband(50e6 - 1e6, 50e6 + 1e6),
        sine_amp = 0.2,
        sine_offset = 0,
        return_ranges_plot = False,
        cut_signal_band_from_noise_band = False,
        rng=rng
    ):
    """
    Return N signal to noise ratios determined on
    N different noise + sine realisations.

    Parameters
    ----------
    N : int-like
        Number of realisations; converted with int().
    f_sample : float
        Sample rate. The defaults are in Hz and s, but any consistent
        unit system works (e.g. MHz together with t_length in us).
    t_length : float
        Length of each trace, passed as ``end`` to sampled_time.
    noise_band : passband
        Band in which the noise bandpower is measured.
    noise_sigma : float
        Noise level handed to noisy_sine_sampling.
    f_sine, sine_amp, sine_offset : float
        Sine parameters. The phase entry is left as None; presumably
        noisy_sine_sampling chooses it (TODO confirm).
    signal_band : passband
        Band in which the signal bandpower is measured.
    return_ranges_plot : bool
        If True, plot the spectrum of the first realisation with the
        sine frequency, bands and band powers indicated, plus the
        signal band-passed to signal_band.
    cut_signal_band_from_noise_band : bool
        If True, measure the noise bandpower only in the parts of
        noise_band below and above signal_band.
    rng : numpy.random.Generator
        Passed on to noisy_sine_sampling. The default is the module-level
        generator as it was when this function was defined.

    Returns
    -------
    snrs : ndarray, shape (N,)
        sqrt(signal bandpower / noise bandpower) for each realisation.
    axs : axes or None
        Axes of the spectrum plot if return_ranges_plot, otherwise None.
    """
    N = int(N)

    # object array because the phase (index 2) is None
    init_params = np.array([sine_amp, f_sine, None, sine_offset])

    axs = None
    snrs = np.zeros( N )
    time = sampled_time(f_sample, end=t_length)
    for j in range(N):
        samples, noise = noisy_sine_sampling(time, init_params, noise_sigma, rng=rng)

        # determine signal to noise
        if cut_signal_band_from_noise_band:
            # NOTE(review): assumes signal_band lies inside noise_band,
            # otherwise one of these sub-bands has inverted edges.
            lower_noise_band = passband(noise_band[0], signal_band[0])
            upper_noise_band = passband(signal_band[1], noise_band[1])

            noise_power = bandpower(noise, f_sample, lower_noise_band)
            noise_power += bandpower(noise, f_sample, upper_noise_band)
        else:
            noise_power = bandpower(noise, f_sample, noise_band)

        signal_power = bandpower(samples, f_sample, signal_band)

        snrs[j] = np.sqrt(signal_power/noise_power)

        # make a nice plot showing what ranges were taken
        # and the bandpowers associated with them (first realisation only)
        if return_ranges_plot and j == 0:
            axs = _plot_snr_ranges(
                    samples, noise, f_sample, f_sine,
                    noise_band, signal_band,
                    init_params[2], # read after sampling, as before
                    signal_power, noise_power,
                    )

    return snrs, axs


def _plot_snr_ranges(samples, noise, f_sample, f_sine, noise_band, signal_band, phase, signal_power, noise_power, freq_scaler=1):
    """
    Plot the combined spectrum with the sine frequency, the noise and
    signal bands and the (square-rooted) band powers indicated, and plot
    the signal band-passed to signal_band in a separate figure.

    Returns the axes of the spectrum figure.
    """
    combined_fft, freqs = ft_spectrum(samples+noise, f_sample)

    # plot the spectrum
    _, axs = plot_combined_spectrum(combined_fft, freqs, freq_scaler=freq_scaler, freq_unit='MHz')

    # indicate band ranges and frequency
    for ax in axs:
        ax.axvline(f_sine/freq_scaler, color='r', alpha=0.4)
        ax.axvspan(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, color='purple', alpha=0.3, label='noiseband')
        ax.axvspan(signal_band[0]/freq_scaler, signal_band[1]/freq_scaler, color='orange', alpha=0.3, label='signalband')

    # indicate initial phase
    axs[1].axhline(phase, color='r', alpha=0.4)

    # plot the band powers; both lines span the noise band so that
    # their heights can be compared directly
    powerax = axs[0]
    powerax.hlines(np.sqrt(signal_power), noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, colors=['orange'], zorder=5)
    powerax.hlines(np.sqrt(noise_power), noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, colors=['purple'], zorder=5)

    powerax.set_ylim(bottom=0)

    axs[0].legend()

    # plot signal_band pass signal: zero every FFT bin outside the band
    bp_freqs = np.fft.fftfreq(len(samples), 1/f_sample)
    bandmask = bandpass_mask(bp_freqs, band=signal_band)
    fft = np.fft.fft(samples)
    fft[ ~bandmask ] = 0
    bandpassed_samples = np.fft.ifft(fft)

    _, ax3 = plt.subplots()
    ax3 = plot_signal(bandpassed_samples, sample_rate=f_sample/freq_scaler, time_unit='us', ax=ax3)
    ax3.set_title("Bandpassed Signal")

    return axs


if __name__ == "__main__":
    from argparse import ArgumentParser
    from myscriptlib import save_all_figs_to_path_or_show

    # Fixed seed so that the generated figures are reproducible.
    # This rebinds the module-level `rng`, but the default argument of
    # noisy_sine_realisation_snr still refers to the old generator, so
    # `rng` is passed explicitly below.
    rng = np.random.default_rng(1)

    parser = ArgumentParser(description=__doc__)
    parser.add_argument("fname", metavar="path/to/figure[/]", nargs="?", help="Location for generated figure, will append __file__ if a directory. If not supplied, figure is shown.")

    args = parser.parse_args()
    default_extensions = ['.pdf', '.png']

    # the literal 'none' behaves like not supplying a path (show figures)
    if args.fname == 'none':
        args.fname = None

    ###
    # Units used below: frequencies in MHz, times in us.
    t_lengths = np.linspace(1, 50, 50) # us
    N = 50e0
    fs_sine = [33.3, 50, 73.3] # MHz
    fs_sample = [250, 500] # MHz

    # Experiment toggles.
    # Halve the trace lengths for the second sample rate, so both sample
    # rates cover the same N = T*f_s values.
    equal_N_across_sample_rates = False
    # Instead of the SNR sweep, show single realisations with a spectrum
    # plot each.
    inspect_single_realisations = False

    if equal_N_across_sample_rates:
        # show t_length and fs_sample really don't care
        fs_iter = [ (fs_sample[0], f_sine, t_lengths) for f_sine in fs_sine ]
        fs_iter += [ (fs_sample[1], f_sine, t_lengths/2) for f_sine in fs_sine ]
    else:
        fs_iter = [ (f_sample, f_sine, t_lengths) for f_sample in fs_sample for f_sine in fs_sine ]

    if inspect_single_realisations:
        f_sine = fs_sine[0]
        f_sample = fs_sample[0]
        N = 1 # Note: keep this low, N figures will be displayed!
        N_t_length = 10
        # NOTE(review): this slice takes the N_t_length lengths before the
        # last one (the final t_length is skipped) -- confirm intended.
        for t_length in t_lengths[-N_t_length-1:-1]:
            snrs = np.zeros( int(N))
            for i in range(int(N)):
                delta_f = 1/t_length # MHz, since t_length is in us
                signal_band = passband(f_sine- 3*delta_f, f_sine + 3*delta_f)
                noise_band = passband(30, 80) # MHz

                realisation_snrs, axs = noisy_sine_realisation_snr(
                        N=1,
                        t_length=t_length,
                        f_sample=f_sample,

                        # signal properties
                        f_sine = f_sine,
                        sine_amp = 1,
                        noise_sigma = 1,

                        noise_band = noise_band,
                        signal_band = signal_band,

                        # the returned axes are annotated below; without
                        # the ranges plot, axs would be None
                        return_ranges_plot=True,
                        rng=rng,
                        )
                snrs[i] = realisation_snrs[0]

                axs[0].set_title("SNR: {}, N:{}".format(snrs[i], t_length*f_sample))
                # frequencies are already in MHz here (freq_scaler=1)
                axs[0].set_xlim(
                        f_sine - 20*delta_f,
                        f_sine + 20*delta_f
                        )

            print(snrs, "M:",np.mean(snrs))

            plt.show(block=False)
    else:
        #original code
        sine_amp = 1
        noise_sigma = 4

        # Which t_length gets a ranges plot; only ever for the first
        # (f_sample, f_sine) combination.
        ranges_plot_first_t_length = False
        ranges_plot_last_t_length = True

        my_snrs = np.zeros( (len(fs_iter), len(t_lengths), int(N)) )
        for i, (f_sample, f_sine, lengths) in enumerate( fs_iter ):
            for k, t_length in enumerate(lengths):
                is_first = k == 0
                is_last = k == len(lengths) - 1
                # parenthesised so that `i == 0` applies to both toggles
                return_ranges_plot = i == 0 and (
                        (is_first and ranges_plot_first_t_length)
                        or (is_last and ranges_plot_last_t_length)
                        )

                delta_f = 1/t_length
                signal_band = passband( *(f_sine + 2*delta_f*np.array([-1,1])) )
                noise_band=passband(30, 80) # MHz

                my_snrs[i,k], _ = noisy_sine_realisation_snr(
                        N=N,
                        t_length=t_length,
                        f_sample = f_sample,

                        # signal properties
                        f_sine = f_sine,
                        sine_amp = sine_amp,
                        noise_sigma = noise_sigma,

                        noise_band = noise_band,
                        signal_band = signal_band,

                        return_ranges_plot=return_ranges_plot,
                        rng=rng
                        )

        # colour per sine frequency and marker per sample rate; shared by
        # both figures so they stay visually consistent
        mycolors = {}
        myshapes = { 250: '^', 500: 'v' }

        def plot_snrs(ax, scale_by_f_sample, mew):
            """
            For every (f_sample, f_sine) combination, plot the mean SNR as
            markers and each realisation as faint dots, against the trace
            length (times f_sample if scale_by_f_sample).
            """
            n_realisations = my_snrs.shape[-1]
            alpha = max(0.01, 1/n_realisations)

            for i, (f_sample, f_sine, lengths) in enumerate(fs_iter):
                xs = lengths*f_sample if scale_by_f_sample else lengths

                # plot the means
                mean_lines = ax.plot(xs, np.mean(my_snrs[i], axis=-1), color=mycolors.get(f_sine, None), marker=myshapes.get(f_sample, 'x'), ls='none', label='f:{}MHz, fs:{}MHz'.format(f_sine, f_sample), markeredgecolor='black', mew=mew)

                # remember the (possibly auto-assigned) style for reuse
                color = mean_lines[0].get_color()
                mycolors[f_sine] = color
                myshapes[f_sample] = mean_lines[0].get_marker()

                # plot the individual realisations
                for k, x in enumerate(xs):
                    ax.plot(np.repeat(x, n_realisations), my_snrs[i,k], ls='none', color=color, marker='o', alpha=alpha)

        # plot the snrs
        fig, axs2 = plt.subplots()
        fig.basefname="signal_to_noise_vs_N"
        axs2.set_title("A: {:.2e}, $\\sigma$: {:.2e}".format(sine_amp, noise_sigma))
        axs2.set_xlabel("$N = T*f_s$")
        axs2.set_ylabel("SNR")

        plot_snrs(axs2, scale_by_f_sample=True, mew=0.1)

        axs2.legend()

        # plot snrs vs T
        fig, axs3 = plt.subplots()
        fig.basefname="signal_to_noise_vs_T"
        axs3.set_title("A: {:.2e}, $\\sigma$: {:.2e}".format(sine_amp, noise_sigma))
        axs3.set_xlabel("time [us]")
        axs3.set_ylabel("SNR")

        plot_snrs(axs3, scale_by_f_sample=False, mew=1)

        axs3.legend()

    ### Save or show figures
    save_all_figs_to_path_or_show(args.fname, default_basename=__file__, default_extensions=default_extensions)