import matplotlib.pyplot as plt
import numpy as np

from scipy import stats
from scipy import special
from scipy import optimize
from itertools import zip_longest

def expectation(x,pdfx):
    """Return the first moment of a sampled pdf, estimated by a Riemann sum.

    The sample spacing is taken from the first two entries of ``x``, so ``x``
    is assumed to be uniformly spaced.
    """
    step = x[1] - x[0]
    weighted = x * pdfx * step
    return np.sum(weighted)

def variance(x,pdfx):
    """Return the variance of a sampled pdf as ``E[x**2] - E[x]**2``.

    Both moments are Riemann sums using the spacing of the first two samples,
    so ``x`` is assumed to be uniformly spaced.
    """
    step = x[1] - x[0]
    second_moment = np.sum(x**2 * pdfx * step)
    mean = expectation(x, pdfx)
    return second_moment - mean**2

def random_phase_sum_distribution(theta, sigma, s=1):
    """Evaluate the phase pdf of a fixed phasor plus Gaussian noise.

    Only the ratio ``k = s/sigma`` enters the expression:

        p(theta) = exp(-k**2/2)/(2 pi)
                   + k cos(theta)/(2 sqrt(2 pi)) * exp(-(k sin(theta))**2/2)
                     * (1 + erf(k cos(theta)/sqrt(2)))

    ``theta`` may be a scalar or array-like; the result broadcasts with it.
    """
    phi = np.asarray(theta)
    k = s / sigma
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    two_pi = 2 * np.pi

    uniform_floor = np.exp(-k**2 / 2) / two_pi
    gaussian_envelope = (two_pi**-0.5) * k * np.exp(-(k * sin_phi)**2 / 2)
    erf_factor = (1. + special.erf(k * cos_phi * 2**-0.5)) * cos_phi / 2
    return uniform_floor + gaussian_envelope * erf_factor

def gaussian_phase_distribution(theta, sigma, s=1):
    """Evaluate a zero-mean normal pdf in ``theta`` with width ``sigma/s``.

    ``theta`` may be a scalar or array-like; the result broadcasts with it.
    """
    inv_width = s / sigma
    phi = np.asarray(theta)
    norm = (2 * np.pi)**-0.5 * inv_width
    return norm * np.exp(-(inv_width * phi)**2 / 2)

def phase_comparison_figure(
        measured_phases,
        true_phases,
        plot_residuals=True,
        f_beacon=None,
        hist_kwargs=None,
        sc_kwargs=None,
        text_kwargs=None,
        colors=('blue', 'orange'),
        legend_on_scatter=True,
        secondary_axis='time',
        fit_gaussian=False,
        fit_randomphasesum=False,
        mean_snr=None,
        return_fit_info=False,
        **fig_kwargs
        ):
    """
    Create a figure comparing measured_phases against true_phases
    by both plotting the values, and the residuals.

    The figure stacks up to two panels vertically (x axis shared by default):
    a histogram of the phases (top) and a scatter of phase versus antenna
    index (bottom).

    Parameters
    ----------
    measured_phases, true_phases : array_like
        Phases to compare.
    plot_residuals : bool
        If True, only ``measured_phases`` is drawn (presumably the caller then
        passes residuals). If False, ``true_phases`` is drawn alongside them
        in both panels.
    f_beacon : float or None
        If truthy, add a secondary x axis on the top panel that converts
        between phase and time via ``phase = 2*pi*t*f_beacon``.
    hist_kwargs : dict, False or None
        Overrides for ``Axes.hist``; ``False`` disables the histogram panel.
    sc_kwargs : dict, False or None
        Overrides for the scatter ``Axes.plot`` calls; ``False`` disables the
        scatter panel.
    text_kwargs : dict or None
        Overrides for the ``Axes.text`` annotations on the histogram.
    colors : sequence of two colors
        Colors for the measured and the true phases, respectively.
    legend_on_scatter : bool
        Draw a legend on the scatter panel when true phases are shown.
    secondary_axis : {'time', 'phase'}
        Quantity shown on the secondary x axis.
    fit_gaussian, fit_randomphasesum : bool
        Fit the respective distribution to the histogram.
    mean_snr : float or None
        If given, annotate the histogram with this mean SNR.
    return_fit_info : bool
        Additionally return the histogram fit information.
    **fig_kwargs
        Passed to ``plt.subplots`` (default ``sharex=True``).

    Returns
    -------
    fig, or (fig, fit_info) if ``return_fit_info``
        ``fit_info`` is the dict returned by ``fitted_histogram_figure``,
        or ``{}`` when the histogram panel is disabled.

    Raises
    ------
    ValueError
        If both panels are disabled.
    """
    default_fig_kwargs = dict(sharex=True)
    default_hist_kwargs = dict(bins='sqrt', density=False, alpha=0.8, histtype='step')
    default_text_kwargs = dict(fontsize=14, verticalalignment='top')
    default_sc_kwargs = dict(alpha=0.6, ls='none')

    # ``False`` disables a panel; ``None`` (or {}) means "no overrides".
    do_hist_plot = hist_kwargs is not False
    do_scatter_plot = sc_kwargs is not False

    fig_kwargs = {**default_fig_kwargs, **fig_kwargs}
    hist_kwargs = {**default_hist_kwargs, **(hist_kwargs or {})}
    text_kwargs = {**default_text_kwargs, **(text_kwargs or {})}
    sc_kwargs = {**default_sc_kwargs, **(sc_kwargs or {})}

    n_panels = int(do_hist_plot) + int(do_scatter_plot)
    if n_panels == 0:
        raise ValueError("At least one of the histogram or scatter panels must be enabled")

    fig, axs = plt.subplots(n_panels, 1, **fig_kwargs)

    # plt.subplots returns a bare Axes for a single panel; normalise to a sequence.
    if not hasattr(axs, '__len__'):
        axs = [axs]

    # Hand out the panels in stacking order so that disabling the histogram
    # does not leave the scatter plot pointing at a non-existent axis.
    panels = iter(axs)
    hist_ax = next(panels) if do_hist_plot else None
    scatter_ax = next(panels) if do_scatter_plot else None

    if f_beacon and secondary_axis in ('phase', 'time'):
        def phase2time(x):
            return x/(2*np.pi*f_beacon)

        def time2phase(x):
            return 2*np.pi*x*f_beacon

        if secondary_axis == 'time':
            functions = (phase2time, time2phase)
            label = 'Time $\\varphi/(2\\pi f_{beac})$ [ns]'
        else:
            functions = (time2phase, phase2time)
            label = 'Phase $2\\pi t f_{beac}$ [rad]'

        secax = axs[0].secondary_xaxis('top', functions=functions)
        secax.set_xlabel(label)

    # Histogram
    fit_info = {}
    if do_hist_plot:
        hist_ax.set_ylabel("#")

        fit_distr = []
        if fit_gaussian:
            fit_distr.append('gaussian')
        if fit_randomphasesum:
            fit_distr.append('randomphasesum')

        _, fit_info = fitted_histogram_figure(
                measured_phases,
                ax=hist_ax,
                text_kwargs=text_kwargs,
                hist_kwargs={**hist_kwargs, **dict(label='Measured', color=colors[0], ls='solid')},
                mean_snr=mean_snr,
                fit_distr=fit_distr,
                )

        if not plot_residuals: # also plot the true clock phases
            # Reuse the measured histogram's bin edges so both are comparable;
            # explicit styling overrides hist_kwargs (as for the measured one)
            # instead of raising a duplicate-keyword TypeError.
            hist_ax.hist(
                true_phases,
                **{**hist_kwargs, **dict(bins=fit_info['bins'], color=colors[1], label='Actual', ls='dashed')},
                )

    # Scatter plot
    if do_scatter_plot:
        scatter_ax.set_ylabel("Antenna no.")
        scatter_ax.plot(measured_phases, np.arange(len(measured_phases)), marker='x' if plot_residuals else '3', color=colors[0], label='Measured', **sc_kwargs)

        if not plot_residuals: # also plot the true clock phases
            scatter_ax.plot(true_phases, np.arange(len(true_phases)), marker='4', color=colors[1], label='Actual', **sc_kwargs)

            if legend_on_scatter:
                scatter_ax.legend()

    fig.tight_layout()

    if return_fit_info:
        return fig, fit_info

    return fig


def _fit_phase_pdf(pdf, bin_centers, counts, scale):
    """Least-squares fit of ``scale * pdf(theta, sigma, s)`` to histogram counts.

    Returns the fitted parameters ``(sigma, s)``. The phase pdfs
    (``random_phase_sum_distribution``, ``gaussian_phase_distribution``)
    only depend on the ratio ``s/sigma``, so the bounds pin ``s`` to
    (almost) 1 to keep the fit well-posed.
    """
    def model(theta, sigma, s):
        return scale * pdf(theta, sigma, s)

    # ((lower sigma, lower s), (upper sigma, upper s))
    bounds = ((0, 0.9999), (np.inf, 1))
    fit_params, _pcov = optimize.curve_fit(model, bin_centers, counts, bounds=bounds)
    return fit_params


def fitted_histogram_figure(
        amplitudes,
        fit_distr = None,
        calc_chisq = True,
        text_kwargs=None,
        hist_kwargs=None,
        mean_snr = None,
        ax = None,
        **fig_kwargs
    ):
    """
    Create a figure showing $amplitudes$ as a histogram.
    If fit_distr is a (list of) string, also fit the respective
    distribution function and show the parameters on the figure.

    Parameters
    ----------
    amplitudes : array_like
        Samples to histogram.
    fit_distr : str, list of str or None
        Any of 'rice', 'gaussian', 'rayleigh' (maximum-likelihood fit of the
        samples via ``scipy.stats``) or 'randomphasesum', 'gaussphase'
        (least-squares fit to the histogram via ``scipy.optimize.curve_fit``).
    calc_chisq : bool
        Compute a chi-square statistic for the ``scipy.stats`` fits (the
        curve_fit distributions have no cdf available here).
    text_kwargs, hist_kwargs : dict or None
        Overrides merged onto the defaults for ``Axes.text`` / ``Axes.hist``.
    mean_snr : float or None
        If truthy, annotate the axes with this mean SNR.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure is created when None.
    **fig_kwargs
        Passed to ``plt.subplots`` when ``ax`` is None.

    Returns
    -------
    fig : matplotlib Figure
    info : dict
        ``fit_info``: list of per-distribution dicts with 'name',
        'param_names', 'param_values', 'text_str' and, when computed,
        'chisq' and 'dof' (number of bins - 1 - number of fitted params);
        ``counts`` and ``bins`` as returned by ``Axes.hist``.

    Raises
    ------
    ValueError
        For an unknown distribution name.
    """
    default_hist_kwargs = dict(bins='sqrt', density=False, alpha=0.8, histtype='step', label='hist')
    default_text_kwargs = dict(fontsize=14, verticalalignment='top')

    if isinstance(fit_distr, str):
        fit_distr = [fit_distr]

    hist_kwargs = {**default_hist_kwargs, **(hist_kwargs or {})}
    text_kwargs = {**default_text_kwargs, **(text_kwargs or {})}

    if ax is None:
        fig, ax = plt.subplots(1, 1, **fig_kwargs)
    else:
        fig = ax.get_figure()

    # Text positions below are given in axes coordinates.
    text_kwargs['transform'] = ax.transAxes

    counts, bins, _patches = ax.hist(amplitudes, **hist_kwargs)

    fit_info = []
    if fit_distr:
        bin_centers = bins[:-1] + np.diff(bins) / 2
        dx = bins[1] - bins[0]  # assumes equal-width bins

        # Factor converting a normalised pdf into the units of ``counts``:
        # expected counts per bin for a plain histogram, unity when the
        # histogram is already normalised to a density.
        if hist_kwargs.get('density'):
            scale = 1
        else:
            scale = len(amplitudes) * dx

        xs = np.linspace(min(amplitudes), max(amplitudes))

        for distr in fit_distr:
            fit_params2text_params = lambda x: x
            fit_ys = None
            fit_params = None
            cdf = None

            if 'rice' == distr:
                name = "Rice"
                param_names = [ "$\\nu$", "$\\sigma$" ]
                distr_func = stats.rice

                # stats.rice.fit returns (b, loc, scale) with nu = b*scale, sigma = scale
                fit_params2text_params = lambda x: (x[0]*x[2], x[2])

            elif 'gaussian' == distr:
                name = "Norm"
                param_names = [ "$\\mu$", "$\\sigma$" ]
                distr_func = stats.norm

            elif 'rayleigh' == distr:
                name = "Rayleigh"
                param_names = [ "$\\sigma$" ]
                distr_func = stats.rayleigh

                # NOTE(review): stats.rayleigh.fit returns (loc, scale) and sigma
                # is usually `scale`; the intent of loc + scale/2 is unclear — confirm.
                fit_params2text_params = lambda x: (x[0]+x[1]/2,)

            elif 'randomphasesum' == distr:
                name = "RandPhaseS"
                # Ordered like the displayed values (s, sigma), see mapping below.
                param_names = [ 's', "$\\sigma$" ]
                fit_params = _fit_phase_pdf(random_phase_sum_distribution, bin_centers, counts, scale)
                fit_ys = scale * random_phase_sum_distribution(xs, *fit_params)

                # fit_params are (sigma, s)
                fit_params2text_params = lambda x: (x[1], x[0])

            elif 'gaussphase' == distr:
                name = 'GaussPhase'
                # Ordered like the displayed values (s, sigma), see mapping below.
                param_names = [ 's', "$\\sigma$" ]
                fit_params = _fit_phase_pdf(gaussian_phase_distribution, bin_centers, counts, scale)
                fit_ys = scale * gaussian_phase_distribution(xs, *fit_params)

                # fit_params are (sigma, s)
                fit_params2text_params = lambda x: (x[1], x[0])

            else:
                raise ValueError(f'Unknown distribution function {distr}')

            label = name +"(" + ','.join(param_names) + ')'

            if fit_ys is None:
                # Maximum-likelihood fit on the raw samples (scipy.stats path).
                fit_params = distr_func.fit(amplitudes)
                fit_ys = scale * distr_func.pdf(xs, *fit_params)
                cdf = distr_func.cdf

            ax.plot(xs, fit_ys, label=label)

            chisq = None
            dof = None
            chisq_strs = []
            if calc_chisq and cdf is not None:
                # NOTE(review): with density=True ``counts`` are densities, so
                # this is not a proper chi-square statistic in that case.
                expected = np.diff(cdf(bins, *fit_params)) * np.sum(counts)
                # The fitted pdf has mass outside the histogram range; rescale
                # so the expected total matches the observed total (recent
                # scipy versions reject mismatched totals in chisquare).
                expected *= np.sum(counts) / np.sum(expected)
                chisq = stats.chisquare(counts, expected, ddof=len(fit_params))[0]
                dof = len(counts) - 1 - len(fit_params)
                chisq_strs = [
                        f"$\\chi^2$/dof = {chisq: .2g}/{dof}"
                        ]

            # change parameters if needed
            text_fit_params = fit_params2text_params(fit_params)

            text_str = "\n".join(
                [label]
                +
                [ f"{param} = {value: .2e}" for param, value in zip_longest(param_names, text_fit_params, fillvalue='?') ]
                +
                chisq_strs
                )

            this_info = {
                    'name': name,
                    'param_names': param_names,
                    'param_values': text_fit_params,
                    'text_str': text_str,
                }

            if chisq is not None:
                this_info['chisq'] = chisq
                this_info['dof'] = dof

            fit_info.append(this_info)

        loc = (0.02, 0.95)
        ax.text(*loc, "\n\n".join([info['text_str'] for info in fit_info]), **{**text_kwargs, **dict(ha='left')})

    if mean_snr:
        text_str = f"$\\langle SNR \\rangle$ = {mean_snr: .1e}"
        loc = (0.98, 0.95)
        ax.text(*loc, text_str, **{**text_kwargs, **dict(ha='right')})

    return fig, dict(fit_info=fit_info, counts=counts, bins=bins)