#!/usr/bin/env python3
# vim: fdm=indent ts=4

__doc__ = \
"""
Sample sine wave + noise
Filter it
Then fit in t-domain to resolve \\varphi_0
"""

import matplotlib.pyplot as plt
import numpy as np
# Spectra are used as complex half-spectra: guess_sine_parameters takes
# np.angle() of a single bin to guess the phase. numpy.fft.rfft provides that;
# scipy.fftpack.rfft instead returns a real "packed" array
# [y(0), Re y(1), Im y(1), ...] whose np.angle() is only ever 0 or pi.
import numpy.fft as ft
import scipy.optimize as opt
from scipy.signal import hilbert


from mylib import *

# Module-level generator used by the simulation; re-seeded under __main__
rng = np.random.default_rng()

def guess_sine_parameters(samples, fft=None, fft_freqs=None, guess=None):
    """
    Use crude methods to guess the parameters to a sine wave
    from properties of both samples and their fourier transform.

    Parameters:
    -----------
    samples - arraylike
        Time-domain samples of the signal.

    fft - arraylike or None
        Spectrum of `samples`. Needed for the frequency and phase guesses.
        The phase guess is np.angle() of a single bin, so a complex
        spectrum (e.g. numpy.fft.rfft output) is expected.

    fft_freqs - arraylike or None
        Frequencies of the bins of `fft`. Needed to map a frequency guess
        onto a bin, and a bin onto a frequency guess.

    guess - arraylike or float or None
        If float (or None), this is interpreted as a frequency
        Order of parameters: [amplitude, frequency, phase, baseline]
        If one parameter is None, it is filled with an approximate value if available.
        The object passed in is never modified.

    Returns:
    -----------
    guess - list
        An updated copy of `guess`: [amplitude, frequency, phase, baseline].
        Entries without an available estimate stay None.
    """

    if not hasattr(guess, '__len__'):
        # interpret as a frequency (might still be None)
        guess = [None, guess, None, None]
    else:
        # Work on a copy: never fill in the caller's list (or a tuple)
        guess = list(guess)

    assert len(guess) == 4, "Wrong length for initial guess (should be 4)"

    samples = np.asarray(samples)

    nearest_f, nearest_phase = None, None
    if fft is not None and (guess[1] is None or guess[2] is None):
        nearest_idx = None
        if guess[1] is not None:
            # Use the bin closest to the given frequency
            if fft_freqs is not None:
                nearest_idx = find_nearest(guess[1], fft_freqs)
        else:
            # Take the strongest peak by default. Rank by magnitude:
            # argmax on complex values would compare the real parts first.
            nearest_idx = np.argmax(np.abs(fft))

        if nearest_idx is not None:
            if fft_freqs is not None:
                nearest_f = fft_freqs[nearest_idx]

            nearest_phase = np.angle(fft[nearest_idx])

    if guess[0] is None:
        # amplitude: largest excursion above the mean
        # (an RMS-based alternative would be np.std(samples) * np.sqrt(2))
        guess[0] = np.max(samples - np.mean(samples))
    if guess[1] is None:
        guess[1] = nearest_f
    if guess[2] is None:
        guess[2] = nearest_phase
    if guess[3] is None:
        # baseline: mean of the samples
        guess[3] = np.mean(samples)

    return guess

def fit_sine_to_samples(time, samples, samplerate=1, bandpass=None, guess=None, fitfunc=sine_fitfunc, fft=None, freqs=None, bounds=None, restrained_fit=False, **curve_kwargs):
    """
    Fit a sine wave to `samples` with scipy.optimize.curve_fit,
    optionally after bandpass filtering in the frequency domain.

    Parameters:
    -----------
    time - arraylike
        Sample times, passed to `fitfunc`.
    samples - arraylike
        Samples to fit.
    samplerate - float
        Sampling rate, used to compute the FFT bin frequencies.
    bandpass - (low, high) or None
        If given, spectral bins outside [low, high] are zeroed and the samples
        are reconstructed from the filtered spectrum before fitting.
    guess - arraylike or float or None
        Initial guess [amplitude, frequency, phase, baseline]; None entries are
        estimated with guess_sine_parameters. A bare float is a frequency.
        Never modified.
    fitfunc - callable
        Model handed to curve_fit. For a restrained fit it is called with
        keyword arguments amp, phase, freq and baseline.
    fft, freqs - arraylike or None
        Precomputed spectrum of `samples` and its bin frequencies; computed
        when needed. A supplied `fft` is not modified.
    bounds - (low, high) or None
        Bounds on [amplitude, frequency, phase, baseline]. Default: phase in
        [-pi, pi], others free. For a restrained fit only entries 0 and 2 are
        used; default amplitude in [0.8, 1.2], phase in [-pi, pi].
    restrained_fit - bool
        If True only amplitude and phase are fitted. Frequency and baseline are
        frozen at their guesses, and the samples are divided by their Hilbert
        envelope, so the fitted amplitude is relative to that envelope.
    **curve_kwargs
        Forwarded to scipy.optimize.curve_fit.

    Returns:
    --------
    fit - sequence or None
        curve_fit output; fit[0] always holds all four parameters
        [amplitude, frequency, phase, baseline]. For a restrained fit the
        covariance fit[1] only covers (amplitude, phase).
        None if curve_fit raised RuntimeError (no convergence).
    guess - ndarray
        The full initial guess of all four parameters.
    (fft, freqs, samples)
        Spectrum (after bandpass), bin frequencies, and the samples that were
        actually fitted (bandpassed and, if restrained, envelope-normalised).
    """
    if not hasattr(guess, '__len__'):
        # a bare number (or None) is a frequency guess
        guess = [None, guess, None, None]
    else:
        # copy, so neither the caller's list nor a shared default is filled in
        guess = list(guess)
    samples = np.asarray(samples)

    if bandpass is not None or guess[1] is None or guess[2] is None:
        if fft is None:
            fft = ft.rfft(samples)
        elif bandpass:
            # copy, so zeroing bins below leaves the caller's spectrum intact
            fft = np.array(fft)
        if freqs is None:
            freqs = ft.rfftfreq(samples.size, 1/samplerate)

        if bandpass:
            fft[(freqs < bandpass[0]) | (freqs > bandpass[1])] = 0
            samples = ft.irfft(fft, samples.size)

    guess = np.array(guess_sine_parameters(samples, fft=fft, fft_freqs=freqs, guess=guess))

    if restrained_fit:
        # Only allow amplitude and phase to be fitted; frequency and baseline
        # are frozen at their guesses (bound now, not looked up later).
        frequency = guess[1]
        baseline = guess[3]

        # Divide out the Hilbert envelope of the (bandpassed) samples,
        # so the amplitude to fit should be close to 1
        envelope = np.abs(hilbert(samples))
        samples = samples/envelope

        def model(t, amplitude, phase):
            return fitfunc(t, amp=amplitude, phase=phase, freq=frequency, baseline=baseline)

        if bounds is None:
            low_bounds = np.array([0.8, -np.pi])
            high_bounds = np.array([1.2, np.pi])
        else:
            low_bounds = np.asarray(bounds[0])[[0, 2]]
            high_bounds = np.asarray(bounds[1])[[0, 2]]
        bounds = (low_bounds, high_bounds)

        # curve_fit rejects a p0 outside the bounds. The amplitude guess was
        # taken from the un-normalised samples, so pull it into range.
        p0 = np.clip(guess[[0, 2]].astype(float), low_bounds, high_bounds)
    else:
        model = fitfunc
        p0 = guess
        if bounds is None:
            high_bounds = np.array([np.inf, np.inf, +1*np.pi, np.inf])
            bounds = (-1*high_bounds, high_bounds)

    print(bounds, p0)

    try:
        fit = opt.curve_fit(model, time, samples, p0=p0, bounds=bounds, **curve_kwargs)
    except RuntimeError:
        # curve_fit did not converge; report failure to the caller
        fit = None

    if fit is not None and restrained_fit:
        # merge the fitted (amplitude, phase) back into a full parameter vector
        popt = np.array([fit[0][0], guess[1], fit[0][1], guess[3]])
        fit = [popt, *fit[1:]]

    return fit, guess, (fft, freqs, samples)

def chi_sq(observed, expected):
    """
    Simple Pearson chi^2 statistic: sum((observed - expected)**2 / expected).

    Parameters:
    -----------
    observed, expected - arraylike
        Observed values and their expectations; broadcast against each other.
        Zero entries in `expected` give inf/nan.

    Returns:
    --------
    The chi^2 sum (numpy scalar).
    """
    observed = np.asarray(observed)
    expected = np.asarray(expected)
    return np.sum((observed - expected)**2 / expected)

def dof(observed, n_parameters=1):
    """
    Degrees of freedom of a fit: the number of observations
    minus the number of fitted parameters.
    """
    n_observations = len(observed)
    return n_observations - n_parameters

def _mark_frequency_bands(axes, f_sine, snr_band, noise_band, freq_scaler=1):
    """Mark the sine frequency, the signal band and the noise band on each of `axes`."""
    for ax in axes:
        ax.axvline(f_sine/freq_scaler, color='r', alpha=0.4) # f_beacon
        ax.axvspan(snr_band[0]/freq_scaler, snr_band[1]/freq_scaler, color='purple', alpha=0.3, label='signalband') # snr
        ax.axvspan(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler, color='orange', alpha=0.3, label='noiseband') # noise_band


def _plot_bandpassed_fit(time, samples, bandpassed, fit_params, init_phase, snr,
                         f_sample, f_sine, snr_band, noise_band, freq_scaler=1):
    """
    Plot the bandpassed samples with the fitted sine, then a samples-vs-model
    figure with residuals. Returns the axes of the first figure.
    """
    envelope = np.abs(hilbert(bandpassed))

    # Evaluate the fitted sine with the Hilbert envelope as its amplitude
    model_params = np.asarray(fit_params).tolist()
    model_params[0] = envelope
    fitted_sine = sine_fitfunc(time, *model_params)

    _, axs = plot_signal_and_spectrum(
            bandpassed, f_sample, "Bandpassed samples\nS/N:{:.2e}".format(snr),
            freq_unit='MHz', freq_scaler=freq_scaler,
            signal_kwargs=dict(alpha=0.8, time_unit='us')
            )
    _mark_frequency_bands(axs[[1, 2]], f_sine, snr_band, noise_band, freq_scaler)

    fit_lines = axs[0].plot(time, fitted_sine, label='fit', alpha=0.8)
    axs[0].plot(time, envelope, label='envelope')

    # indicate initial (true) and fitted phase
    axs[2].axhline(init_phase, color='r', alpha=0.4)
    axs[2].axhline(fit_params[2], color=fit_lines[0].get_color(), alpha=0.4)

    axs[0].legend(loc='upper left')
    axs[1].legend()

    # Samples vs model against sample index, with residuals below
    fig, (ax_model, ax_resid) = plt.subplots(2, 1, sharex=True)
    fig.suptitle("Bandpassed Samples vs Model")
    ax_model.set_ylabel("Amplitude")
    ax_model.plot(bandpassed, label='samples', alpha=0.8)
    ax_model.plot(fitted_sine, label='fit', alpha=0.8)
    ax_model.plot(envelope, label='envelope')
    ax_model.plot(samples, label='orig sine', alpha=0.8)
    ax_model.legend()

    ax_resid.set_ylabel("Residuals")
    ax_resid.set_xlabel("Sample")
    ax_resid.plot(samples - fitted_sine, label="Sine - Model", alpha=0.8)
    ax_resid.plot(bandpassed - fitted_sine, label="Bandpassed - Model", alpha=0.8)
    ax_resid.legend()

    return axs


def simulate_noisy_sine_fitting_SNR_and_residuals(
        N=1, snr_band=passband(), noise_band=passband(),
        t_length=1e-6, f_sample=250e6,
        noise_sigma=1, init_params=None,
        show_original_signal_figure=False, show_bandpassed_signal_figure=True,
        restrained_fit=True,
        time=None, freq_scaler=1,
        ):
    """
    Repeatedly fit a randomly phased, noisy sine wave and collect the residuals.

    In each of `N` iterations, a sine with parameters `init_params` and a random
    phase (phasemod of a uniform draw in [0, 2pi)) is sampled. Normal noise
    with standard deviation `noise_sigma` is added, and the sum is fitted with
    fit_sine_to_samples after bandpassing to `snr_band`.

    Parameters:
    -----------
    N - int
        Number of random realisations.
    snr_band, noise_band - (low, high)
        Bands handed to signal_to_noise; `snr_band` is also the fit bandpass.
    t_length, f_sample - float
        Trace length and sampling rate. The time axis is
        sampled_time(f_sample, end=t_length) unless `time` is given.
    noise_sigma - float
        Standard deviation of the noise; falsy means no noise.
    init_params - arraylike or None
        True [amplitude, frequency, phase, baseline]. The phase entry is
        replaced by a random value in every iteration. Not modified.
        None means [1, 50e6, None, 0].
    show_original_signal_figure, show_bandpassed_signal_figure - bool
        Draw diagnostic figures for the first realisation.
    restrained_fit - bool
        Passed on to fit_sine_to_samples.
    time - arraylike or None
        Sample times; computed from `f_sample` and `t_length` if None.
    freq_scaler - float
        Divisor for frequencies on the plot axes.

    Returns:
    --------
    residuals - ndarray (N, len(init_params))
        normalise_sine_params(true - fitted) per realisation; NaN rows
        where the fit failed.
    real_snrs - ndarray (N,)
        Signal-to-noise ratio of each realisation.
    (axs1, axs2)
        Axes of the original / bandpassed figures, or None if not drawn.
    """
    if init_params is None:
        init_params = [1, 50e6, None, 0]
    # Private float copy: entry 2 (phase) is overwritten in every iteration
    # and must not leak into the caller's array. A None entry becomes NaN.
    init_params = np.array(init_params, dtype=float)
    f_sine = init_params[1]
    if time is None:
        time = sampled_time(f_sample, end=t_length)

    N = int(N)
    residuals = np.empty((N, len(init_params)))
    real_snrs = np.empty(N)

    axs1, axs2 = None, None
    for j in range(N):
        if j % 500 == 0:
            print("Iteration {} running".format(j))

        # set random phase
        init_params[2] = phasemod(2*np.pi*rng.random())

        samples = sine_fitfunc(time, *init_params)
        if noise_sigma:
            noise = rng.normal(0, noise_sigma, size=len(samples))
        else:
            noise = np.zeros(len(samples))

        real_snrs[j] = signal_to_noise(samples, noise, signal_band=snr_band, samplerate=f_sample, noise_band=noise_band)

        # diagnostic figures are only drawn for the first realisation
        first_iteration = j == 0

        if show_original_signal_figure and first_iteration:
            _, axs1 = plot_signal_and_spectrum(
                    samples+noise, f_sample, "Original",
                    freq_unit='MHz', freq_scaler=freq_scaler
                    )
            _mark_frequency_bands(axs1[[1, 2]], f_sine, snr_band, noise_band, freq_scaler)

            # indicate initial phase
            axs1[2].axhline(init_params[2], color='r', alpha=0.4)

            axs1[1].legend()

        fit, _, (_, _, bandpassed) = fit_sine_to_samples(
                time, samples+noise, f_sample,
                guess=[None, f_sine, None, None],
                bandpass=snr_band, restrained_fit=restrained_fit
                )

        if fit is None:
            residuals[j] = np.nan
            continue

        residuals[j] = normalise_sine_params(init_params - fit[0])

        if show_bandpassed_signal_figure and first_iteration:
            axs2 = _plot_bandpassed_fit(
                    time, samples, bandpassed, fit[0], init_params[2], real_snrs[j],
                    f_sample, f_sine, snr_band, noise_band, freq_scaler
                    )

            print("init:", init_params)
            print("fit :", fit[0])
            print("res :", residuals[j])

    return residuals, real_snrs, (axs1, axs2)


if __name__ == "__main__":
    from argparse import ArgumentParser
    from myscriptlib import save_all_figs_to_path_or_show

    parser = ArgumentParser(description=__doc__)
    parser.add_argument("fname", metavar="path/to/figure[/]", nargs="?", help="Location for generated figure, will append __file__ if a directory. If not supplied, figure is shown.")
    # const: a bare `-n` without a value would otherwise yield None
    parser.add_argument("-n", "--n-rand", dest='N', default=1, const=1, type=int, nargs='?', help='Number of random sines to fit')
    parser.add_argument('--seed', default=1, type=int, help='RNG seed')
    args = parser.parse_args()
    default_extensions = ['.pdf', '.png']

    if args.fname == 'none':
        args.fname = None

    # Re-seed the module-level generator used by the simulation
    rng = np.random.default_rng(args.seed)

    report_N_nan = True
    restrained_fitting = True

    # NOTE: `f_sine`, `time` and `freq_scaler` are module-level names here and
    # may also be read as globals by the simulation; keep them defined.
    f_sine = 53.123456 # MHz
    sine_amplitude = 1
    sine_baseline = 0
    init_params = np.array([sine_amplitude, f_sine, None, sine_baseline])

    N = int(args.N)
    f_sample = 250 # MHz
    t_length = 10 # us
    noise_sigma = 0.01

    f_delta = 1/t_length
    noise_band = (30,80) # MHz
    snr_band = (f_sine -50*f_delta, f_sine + 50*f_delta)

    time = sampled_time(f_sample, end=t_length)

    freq_scaler = 1

    ###### End of inputs

    param_labels = ["Amplitude", "Frequency", "Phase", "Baseline"]

    residuals, real_snrs, _ = simulate_noisy_sine_fitting_SNR_and_residuals(N=N, snr_band=snr_band, noise_band=noise_band, t_length=t_length, f_sample=f_sample, noise_sigma=noise_sigma, init_params=init_params, restrained_fit=restrained_fitting)

    # Filter NaNs from fit attempts that failed
    nan_mask = ~np.isnan(residuals).any(axis=1)
    if report_N_nan:
        ## report how many NaNs were found
        print("NaNs: {}/{}".format(np.count_nonzero(~nan_mask), len(real_snrs)))

    residuals = residuals[nan_mask]
    real_snrs = real_snrs[nan_mask]

    ## Plot Signal-to-Noise vs Residuals of the fit parameters
    # A restrained fit only varies the phase; otherwise show every parameter,
    # with exactly one axis per shown parameter.
    scatter_params = [2] if restrained_fitting else [0, 1, 2, 3]
    fig, axs = plt.subplots(1, len(scatter_params), sharey=True)

    if not hasattr(axs,'__len__'):
        axs = [axs]

    fig.suptitle("S/N vs Residuals\nS/N Band ({:.2e},{:.2e})MHz \namp/sigma: {}".format(snr_band[0]/freq_scaler, snr_band[-1]/freq_scaler, sine_amplitude/ noise_sigma))
    axs[0].set_ylabel("S/N")

    real_snrs[np.isnan(real_snrs)] = 1 # Show nan values
    # guard against an empty selection when every fit failed
    marker_alpha = max(0.3, 1/max(1, len(real_snrs)))

    for ax, i in zip(axs, scatter_params):
        unit_scaler = [1, freq_scaler][i==1]
        unit_string = ['', '[MHz]'][i==1]

        if i == 2:
            #axis_pi_ticker(ax.xaxis)
            ax.set_xlim(-np.pi, np.pi)

        ax.set_xlabel(param_labels[i] + unit_string)
        ax.plot(residuals[:,i]/unit_scaler, real_snrs, ls='none', marker='o', alpha=marker_alpha)

    ## Plot Histograms of the Residuals
    if N > 1:
        # Amplitude and baseline are never histogrammed; frequency only for unrestrained fits
        hist_params = [2] if restrained_fitting else [1, 2]
        for j in hist_params:
            unit_scaler = [1, freq_scaler][j==1]
            unit_string = ['', '[MHz]'][j==1]
            xlabel = param_labels[j]

            title = xlabel + " residuals"
            title += "\n"
            title += "f: {:.2e}MHz, amp/sigma: {:.2e}".format(f_sine/freq_scaler, sine_amplitude/noise_sigma)
            if noise_band:
                title += " Band ({:.2e},{:.2e})MHz".format(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler)

            fig, ax = plt.subplots()
            ax.set_title(title)
            ax.hist(residuals[:,j]/unit_scaler, density=False, histtype='step', bins='sqrt')
            ax.set_xlabel(xlabel + unit_string)
            ax.set_ylabel("Counts")

            # make it symmetric around 0, using the largest |limit| so no data is cut off
            xmax = max(abs(lim) for lim in ax.get_xlim())
            ax.set_xlim(-xmax, xmax)

            if j == 2: # Phase
                xmin, xmax = ax.get_xlim()
                maj_div = max(1, 2**np.ceil(np.log2(np.pi/(xmax-xmin)) + 1 ))
                min_div = maj_div*12

                #axis_pi_ticker(ax.xaxis, major_divider=maj_div, minor_divider=min_div)

        # Plot histogram between phase and frequency
        if N > 10:
            fig, ax = plt.subplots()
            title = "Residuals\n"
            title += "f: {:.2e}MHz, amp/sigma: {:.2e}".format(f_sine/freq_scaler, sine_amplitude/noise_sigma)
            if noise_band:
                title += "\n Band ({},{})MHz".format(noise_band[0]/freq_scaler, noise_band[1]/freq_scaler)
            title += ", N={:.1e}".format(N)
            ax.set_title(title)
            ax.set_xlabel('Frequency [MHz]')
            ax.set_ylabel('Phase')
            # numpy's histogram2d requires an integer (not float) bin count
            n_bins = max(1, int(np.sqrt(len(residuals))))
            _, _, _, sc = ax.hist2d(residuals[:,1]/freq_scaler, residuals[:,2], bins=n_bins)
            fig.colorbar(sc, ax=ax, label='Counts')

            #ax.set_xlim(-np.pi, np.pi)
            axis_pi_ticker(ax.yaxis)
            ax.set_ylim(-np.pi, np.pi)

    ## Save or show figures
    save_all_figs_to_path_or_show(args.fname, default_basename=__file__, default_extensions=default_extensions)