"""
Functions to simplify plotting of Fourier spectra.
"""
# vim: fdm=indent ts=4

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

from .fft import ft_spectrum

def plot_spectrum(
        spectrum, freqs,
        plot_complex=False, plot_power=False, plot_amplitude=None,
        freq_unit="Hz", freq_scaler=1,
        title='Spectrum', xlabel='Frequency', ax=None,
        **plot_kwargs
        ):
    """
    Plot a signal's spectrum on an Axes object.

    Parameters
    ----------
    spectrum : array_like
        Complex spectrum values, one per entry of *freqs*.
    freqs : array_like
        Frequencies belonging to *spectrum*; divided by *freq_scaler*
        before plotting.
    plot_complex : bool
        Plot the real and imaginary parts as two separate curves.
    plot_power : bool
        Plot ``abs(spectrum)**2``.
    plot_amplitude : bool or None
        Plot ``abs(spectrum)``. A falsy value is still turned on when
        neither *plot_power* nor *plot_complex* is set, so at least one
        curve is always drawn.
    freq_unit : str
        Unit appended to the x label in brackets; omitted when falsy.
    freq_scaler : float
        Divisor applied to *freqs* (e.g. pass 1e3 together with
        ``freq_unit='kHz'``).
    title, xlabel : str
        Axes title and x-axis label.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    **plot_kwargs
        Forwarded to every ``ax.plot`` call. An ``alpha`` entry overrides
        the default transparency (0.5 when *plot_complex*, otherwise 1).

    Returns
    -------
    ax
        The axes that were drawn on.
    """
    plot_amplitude = plot_amplitude or (not plot_power and not plot_complex)
    # Real/imag curves overlap the others, so they default to semi-transparency.
    # Pop a caller-supplied alpha so it is not passed to ax.plot twice
    # (that would raise "got multiple values for keyword argument 'alpha'").
    alpha = plot_kwargs.pop('alpha', 0.5 if plot_complex else 1)

    if ax is None:
        ax = plt.gca()

    ax.set_title(title)
    ax.set_xlabel(xlabel + (" [" + freq_unit + "]" if freq_unit else ""))

    ylabel_parts = []
    if plot_amplitude or plot_complex:
        ylabel_parts.append("Amplitude")
    if plot_power:
        ylabel_parts.append("Power")
    ax.set_ylabel("|".join(ylabel_parts))

    # asarray lets plain Python sequences be scaled as well as ndarrays.
    scaled_freqs = np.asarray(freqs) / freq_scaler

    if plot_complex:
        ax.plot(scaled_freqs, np.real(spectrum), '.-', label='Real', alpha=alpha, **plot_kwargs)
        ax.plot(scaled_freqs, np.imag(spectrum), '.-', label='Imag', alpha=alpha, **plot_kwargs)

    if plot_power:
        ax.plot(scaled_freqs, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha, **plot_kwargs)

    if plot_amplitude:
        ax.plot(scaled_freqs, np.abs(spectrum), '.-', label='Abs', alpha=alpha, **plot_kwargs)

    ax.legend()

    return ax

def plot_phase(
        spectrum, freqs,
        ylim_epsilon=0.5, ax=None, grid=True,
        freq_unit="Hz", freq_scaler=1, xlabel='Frequency',
        major_divider=2, minor_divider=12,
        **plot_kwargs
    ):
    """
    Plot the phase angle of a spectrum against frequency.

    The plotted values are ``np.angle(spectrum)``. The y range is fixed to
    ``[-pi - ylim_epsilon, pi + ylim_epsilon]`` and the y axis receives
    pi-fraction ticks via ``axis_pi_ticker``.

    Parameters
    ----------
    spectrum : array_like
        Complex spectrum values.
    freqs : array_like
        Matching frequencies; divided by *freq_scaler* before plotting.
    ylim_epsilon : float
        Padding added beyond +/- pi on the y axis.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    grid : bool
        Passed to ``ax.grid``.
    freq_unit : str
        Unit appended to the x label in brackets; omitted when falsy.
    freq_scaler : float
        Divisor applied to *freqs*.
    xlabel : str
        Base x-axis label.
    major_divider, minor_divider : int
        Major/minor tick spacing, as ``pi / divider``.
    **plot_kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    ax
        The axes that were drawn on.
    """
    target = plt.gca() if ax is None else ax

    target.set_ylabel("Phase")
    unit_suffix = " [" + freq_unit + "]" if freq_unit else ""
    target.set_xlabel(xlabel + unit_suffix)
    target.grid(grid)

    scaled_freqs = freqs / freq_scaler
    target.plot(scaled_freqs, np.angle(spectrum), '.-', **plot_kwargs)

    lower_limit = -np.pi - ylim_epsilon
    upper_limit = np.pi + ylim_epsilon
    target.set_ylim(lower_limit, upper_limit)

    axis_pi_ticker(target.yaxis, major_divider=major_divider, minor_divider=minor_divider)

    return target

def axis_pi_ticker(axis, major_divider=2, minor_divider=12):
    """
    Place ticks at fractions of pi on a matplotlib Axis.

    Major ticks are located every ``pi / major_divider`` and get a formatter
    that renders them as LaTeX multiples of pi. Minor ticks are located
    every ``pi / minor_divider``; no minor formatter is set here.

    Parameters
    ----------
    axis : matplotlib Axis
        For example ``ax.yaxis``.
    major_divider : int
        Number of major ticks per pi.
    minor_divider : int
        Number of minor ticks per pi.

    Returns
    -------
    axis
        The same axis object, for chaining.
    """
    # Raw strings: a bare '\pi' literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12) even though it yields backslash-pi today.
    major_tick = Multiple(major_divider, np.pi, r'\pi')
    minor_tick = Multiple(minor_divider, np.pi, r'\pi')

    axis.set_major_locator(major_tick.locator())
    axis.set_major_formatter(major_tick.formatter())
    axis.set_minor_locator(minor_tick.locator())

    return axis

def plot_signal(
        signal, sample_rate = 1,
        time=None, ax=None,
        title='Signal', ylabel='Amplitude',
        time_unit="s", xlabel='Time',
        **kwargs):
    """
    Plot a signal in the time domain.

    Parameters
    ----------
    signal : sequence
        Sample values.
    sample_rate : float
        Used to build the time axis (``index / sample_rate``) when *time*
        is not given.
    time : array_like, optional
        Explicit time values; overrides the axis built from *sample_rate*.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    title, ylabel, xlabel : str
        Axes title and axis labels.
    time_unit : str
        Unit appended to the x label in brackets; omitted when falsy.
    **kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    ax
        The axes that were drawn on.
    """
    target = ax if ax is not None else plt.gca()

    if time is None:
        sample_index = np.arange(len(signal))
        time = sample_index / sample_rate

    target.set_title(title)
    unit_suffix = (" [" + time_unit + "]") if time_unit else ""
    target.set_xlabel(xlabel + unit_suffix)
    target.set_ylabel(ylabel)

    target.plot(time, signal, **kwargs)

    return target

def plot_combined_spectrum(
        spectrum, freqs,
        fig=None, gs=None,
        spectrum_ax=None, phase_ax=None,
        spectrum_kwargs=None, phase_kwargs=None,
        **shared_kwargs
        ):
    """
    Plot a spectrum and its phase as two stacked panels.

    If neither *spectrum_ax* nor *phase_ax* is given, a layout is created on
    *fig* (a new figure with figsize=(8, 16) if None) using *gs* (a new 2x1
    GridSpec with height ratios 3:1 and hspace=0 if None). The phase panel
    shares the spectrum panel's x axis. If at least one axis is given, no
    layout is created. A missing axis then falls back to the current axes
    inside ``plot_spectrum`` / ``plot_phase``.

    Parameters
    ----------
    spectrum, freqs : array_like
        Passed to both ``plot_spectrum`` and ``plot_phase``.
    fig : Figure, optional
        Figure on which the layout is created.
    gs : grid spec, optional
        Only its last column is used. All rows but the last hold the
        spectrum; the last row holds the phase.
    spectrum_ax, phase_ax : Axes, optional
        Pre-existing axes to draw on.
    spectrum_kwargs, phase_kwargs : dict, optional
        Keyword arguments for one panel only. They override
        *shared_kwargs* on key clashes.
    **shared_kwargs
        Keyword arguments passed to both panels.

    Returns
    -------
    fig : Figure
        The figure used. When axes were supplied and *fig* was None, this is
        the spectrum axes' figure.
    axes : numpy.ndarray
        ``[spectrum_ax, phase_ax]``, the axes actually drawn on.
    """
    # None sentinels instead of shared mutable {} defaults.
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    phase_kwargs = {} if phase_kwargs is None else phase_kwargs

    if spectrum_ax is None and phase_ax is None:
        # configure plotting layout
        if fig is None:
            fig = plt.figure(figsize=(8, 16))

        if gs is None:
            gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3, 1], hspace=0)

        spectrum_ax = fig.add_subplot(gs[:-1, -1])
        phase_ax = fig.add_subplot(gs[-1, -1], sharex=spectrum_ax)

    # Keep the axes the helpers actually drew on. When only one axis was
    # supplied, the other one is resolved via plt.gca() inside the helper,
    # and we need that object for the tick configuration below.
    ax1 = plot_spectrum(spectrum, freqs, ax=spectrum_ax, **{**shared_kwargs, **spectrum_kwargs})
    ax2 = plot_phase(spectrum, freqs, ax=phase_ax, **{**shared_kwargs, **phase_kwargs})

    # Draw the spectrum's x ticks at the top, but without labels: the phase
    # panel below carries the frequency labels. tick_params persists across
    # tick regeneration, unlike hiding the current label objects.
    ax1.xaxis.tick_top()
    ax1.tick_params(axis='x', which='both', labeltop=False, labelbottom=False)

    if fig is None:
        fig = ax1.figure

    axes = np.array([ax1, ax2])

    return fig, axes

def plot_signal_and_spectrum(
        signal, sample_rate=1, title=None,
        signal_kwargs=None, ft_kwargs=None,
        spectrum_kwargs=None, phase_kwargs=None,
        **phase_spectrum_kwargs
        ):
    """
    Create a figure showing both the signal and the combined spectrum.

    The left half shows *signal* in the time domain. The right half is split
    3:1 (no vertical gap) into the spectrum and its phase, computed with
    ``ft_spectrum``.

    Parameters
    ----------
    signal : array_like
        Time-domain samples.
    sample_rate : float
        Passed to both ``plot_signal`` and ``ft_spectrum``.
    title : str, optional
        Figure suptitle; omitted when falsy.
    signal_kwargs : dict, optional
        Extra keyword arguments for ``plot_signal``.
    ft_kwargs : dict, optional
        Extra keyword arguments for ``ft_spectrum``.
    spectrum_kwargs : dict, optional
        Keyword arguments for the spectrum panel only.
    phase_kwargs : dict, optional
        Keyword arguments for the phase panel only. Defaults to
        ``{'major_divider': 1, 'minor_divider': 6}``. A supplied dict replaces
        this default rather than being merged with it.
    **phase_spectrum_kwargs
        Forwarded to both the spectrum and the phase panel.

    Returns
    -------
    fig : Figure
        The newly created figure (figsize=(16, 4)).
    axes : numpy.ndarray
        ``[signal_ax, spectrum_ax, phase_ax]``.
    """
    # Default arguments are evaluated once, so dict literals there would be
    # shared by every call; build fresh defaults per call instead.
    signal_kwargs = {} if signal_kwargs is None else signal_kwargs
    ft_kwargs = {} if ft_kwargs is None else ft_kwargs
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    if phase_kwargs is None:
        phase_kwargs = {'major_divider': 1, 'minor_divider': 6}

    fig = plt.figure(figsize=(16, 4))

    if title:
        fig.suptitle(title)

    # setup plot layout: left cell for the signal, right cell split 3:1
    gs0 = gridspec.GridSpec(1, 2, figure=fig)
    gs00 = gs0[0].subgridspec(1, 1)
    gs01 = gs0[1].subgridspec(2, 1, height_ratios=[3, 1], hspace=0)

    # plot the signal
    ax1 = fig.add_subplot(gs00[0, 0])
    plot_signal(signal, sample_rate, ax=ax1, **signal_kwargs)

    # plot spectrum
    signal_fft, freqs = ft_spectrum(signal, sample_rate, **ft_kwargs)
    _, (ax2, ax3) = plot_combined_spectrum(
            signal_fft, freqs,
            fig=fig, gs=gs01,
            spectrum_kwargs=spectrum_kwargs, phase_kwargs=phase_kwargs,
            **phase_spectrum_kwargs
            )

    # return the axes
    axes = np.array([ax1, ax2, ax3])

    return fig, axes

def multiple_formatter(denominator=2, number=np.pi, latex=r'\pi'):
    """
    Build a tick-label function that writes values as fractions of *number*.

    Adapted from https://stackoverflow.com/a/53586826

    Parameters
    ----------
    denominator : int
        Label resolution: each value is rounded to the nearest multiple of
        ``number / denominator`` before being written as a reduced fraction.
    number : float
        The base value the labels are expressed in (pi by default).
    latex : str
        LaTeX source used to render *number* in the labels.

    Returns
    -------
    callable
        ``f(x, pos) -> str`` returning a math-mode string. The result is
        ``$0$``, a signed integer multiple of *latex*, or a reduced ``frac``
        of *latex* over an integer. *pos* is ignored; the signature matches
        what ``FuncFormatter`` calls.
    """
    def gcd(a, b):
        # Euclid's algorithm. It is hand-rolled rather than math.gcd because it
        # also tolerates float denominators.
        while b:
            a, b = b, a % b
        return a

    def _multiple_formatter(x, pos):
        den = denominator
        # Nearest multiple of number/denominator. The builtin int replaces
        # the np.int alias, which was removed in NumPy 1.24.
        num = int(np.rint(den * x / number))
        com = gcd(num, den)
        num, den = int(num / com), int(den / com)
        if den == 1:
            if num == 0:
                return r'$0$'
            if num == 1:
                return r'$%s$' % latex
            if num == -1:
                return r'$-%s$' % latex
            return r'$%s%s$' % (num, latex)
        if num == 1:
            return r'$\frac{%s}{%s}$' % (latex, den)
        if num == -1:
            return r'$\frac{-%s}{%s}$' % (latex, den)
        return r'$\frac{%s%s}{%s}$' % (num, latex, den)

    return _multiple_formatter

class Multiple:
    """
    Factory for a matching tick locator and formatter at multiples of
    ``number / denominator``.

    Adapted from https://stackoverflow.com/a/53586826
    """
    def __init__(self, denominator=2, number=np.pi, latex=r'\pi'):
        # Raw-string default: a bare '\pi' is an invalid escape sequence.
        self.denominator = denominator  # tick steps per unit of ``number``
        self.number = number  # base value the ticks are multiples of
        self.latex = latex  # LaTeX used to render ``number`` in labels

    def __repr__(self):
        return "%s(denominator=%r, number=%r, latex=%r)" % (
            type(self).__name__, self.denominator, self.number, self.latex)

    def locator(self):
        """Return a MultipleLocator with spacing ``number / denominator``."""
        return plt.MultipleLocator(self.number / self.denominator)

    def formatter(self):
        """Return a FuncFormatter wrapping ``multiple_formatter`` with these settings."""
        return plt.FuncFormatter(multiple_formatter(self.denominator, self.number, self.latex))