"""
Routines needed to analyse a beacon signal
"""
import numpy as np
from scipy import signal

# Monkey patch correlation_lags into scipy.signal if it does not exist,
# so the rest of this module can always call signal.correlation_lags.
if not hasattr(signal, 'correlation_lags'):
    def correlation_lags(in1_len, in2_len, mode='full'):
        r"""
        Calculates the lag / displacement indices array for 1D cross-correlation.

        Parameters
        ----------
        in1_len : int
            First input size.
        in2_len : int
            Second input size.
        mode : str {'full', 'valid', 'same'}, optional
            A string indicating the size of the output.
            See the documentation `correlate` for more information.

        See Also
        --------
        correlate : Compute the N-dimensional cross-correlation.

        Returns
        -------
        lags : array
            Returns an array containing cross-correlation lag/displacement indices.
            Indices can be indexed with the np.argmax of the correlation to return
            the lag/displacement.

        Raises
        ------
        ValueError
            If `mode` is not one of 'full', 'same' or 'valid'.

        Notes
        -----
        Cross-correlation for continuous functions :math:`f` and :math:`g` is
        defined as:

        .. math::
            \left ( f\star g \right )\left ( \tau \right )
            \triangleq \int_{t_0}^{t_0 +T}
            \overline{f\left ( t \right )}g\left ( t+\tau \right )dt

        Where :math:`\tau` is defined as the displacement, also known as the lag.

        Cross correlation for discrete functions :math:`f` and :math:`g` is
        defined as:

        .. math::
            \left ( f\star g \right )\left [ n \right ]
            \triangleq \sum_{-\infty}^{\infty}
            \overline{f\left [ m \right ]}g\left [ m+n \right ]

        Where :math:`n` is the lag.

        Examples
        --------
        Cross-correlation of a signal with its time-delayed self.

        >>> from scipy import signal
        >>> from numpy.random import default_rng
        >>> rng = default_rng()
        >>> x = rng.standard_normal(1000)
        >>> y = np.concatenate([rng.standard_normal(100), x])
        >>> correlation = signal.correlate(x, y, mode="full")
        >>> lags = signal.correlation_lags(x.size, y.size, mode="full")
        >>> lag = lags[np.argmax(correlation)]
        """

        # calculate lag ranges in different modes of operation
        if mode == "full":
            # the output is the full discrete linear convolution
            # of the inputs. (Default)
            lags = np.arange(-in2_len + 1, in1_len)
        elif mode == "same":
            # the output is the same size as `in1`, centered
            # with respect to the 'full' output.
            # calculate the full output
            lags = np.arange(-in2_len + 1, in1_len)
            # determine the midpoint in the full output
            mid = lags.size // 2
            # determine lag_bound to be used with respect
            # to the midpoint
            lag_bound = in1_len // 2
            # calculate lag ranges for even and odd scenarios
            if in1_len % 2 == 0:
                lags = lags[(mid-lag_bound):(mid+lag_bound)]
            else:
                lags = lags[(mid-lag_bound):(mid+lag_bound)+1]
        elif mode == "valid":
            # the output consists only of those elements that do not
            # rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
            # must be at least as large as the other in every dimension.

            # the lag_bound will be either negative or positive
            # this lets us infer how to present the lag range
            lag_bound = in1_len - in2_len
            if lag_bound >= 0:
                lags = np.arange(lag_bound + 1)
            else:
                lags = np.arange(lag_bound, 1)
        else:
            # Without this branch an unknown mode reaches `return lags`
            # with `lags` unbound and fails with an obscure UnboundLocalError.
            raise ValueError(f"Mode {mode} is invalid")
        return lags

    signal.correlation_lags = correlation_lags

##### end of monkey patch correlation_lags

def beacon_time_delay(samplerate, ref_beacon, beacon):
    """
    Estimate the time delay of `beacon` relative to `ref_beacon`.

    Builds an integer lag grid spanning the 'full' correlation range of the
    two signals (via correlation_grid) and picks the lag with the highest
    correlation (via lag_gridsearch).

    Returns
    -------
    time_lag : float
        Best lag in samples divided by `samplerate`.
    errs : ndarray
        Half-distances to the previous and next grid lag, also divided by
        `samplerate`.
    """
    lag_grid = correlation_grid(
        in1_len=len(ref_beacon),
        in2_len=len(beacon),
        mode='full',
    )
    return lag_gridsearch(lag_grid, samplerate, ref_beacon, beacon)

def beacon_phase_delay(samplerate, f_beacon, ref_beacon, beacon):
    """
    Convert the correlation time delay between two beacons into a phase.

    Delegates to beacon_time_delay and scales both the delay and its error
    bars by the angular frequency 2*pi*f_beacon.

    Returns
    -------
    phase : float
        2*pi*f_beacon times the time delay.
    phase_err : ndarray
        2*pi*f_beacon times the time-delay errors.
    """
    omega = 2*np.pi*f_beacon
    delay, delay_errs = beacon_time_delay(samplerate, ref_beacon, beacon)
    return omega*delay, omega*delay_errs

def beacon_integer_period(samplerate, f_beacon, ref_impulse, impulse, k_step=1):
    """
    Determine the best integer number of beacon periods between two impulses.

    Public entry point; dispatches to the coherent-sum implementation
    (_beacon_integer_period_sum) and returns its result unchanged.
    """
    method = _beacon_integer_period_sum
    return method(samplerate, f_beacon, ref_impulse, impulse, k_step=k_step)

def _beacon_integer_period_sum(samplerate, f_beacon, ref_impulse, impulse, k_step=1):
    """
    Use the maximum of a coherent sum to determine
    the best number of periods of f_beacon.

    Candidates are k = 0, k_step, 2*k_step, ... below max_k, where max_k is
    the number of whole beacon periods that fit in ``ref_impulse``. For each
    k, ``impulse`` is passed through ``util.time_roll`` with a time of
    ``k/f_beacon`` and added to ``ref_impulse``. The k whose sum has the
    largest peak wins; ties go to the smallest k.

    Returns
    -------
    best_k : int
        Winning number of periods.
    (ks, maxima) : tuple of ndarray
        All candidate k values and the peak of the coherent sum for each.

    Raises
    ------
    ValueError
        If there is no candidate k (``ref_impulse`` spans less than one beacon
        period, or ``k_step`` is not positive).
    """
    max_k = int( len(ref_impulse)*f_beacon/samplerate )
    ks = np.arange(0, max_k, step=k_step)

    # Without this guard an empty candidate list surfaces as an opaque
    # IndexError from ks[best_i] below.
    if len(ks) == 0:
        raise ValueError(
            "No candidate periods: max_k={} with k_step={}".format(max_k, k_step)
        )

    maxima = np.empty(len(ks))

    best_i = 0

    for i, k in enumerate(ks):
        # NOTE(review): `util` is not imported in this module as shown;
        # confirm it is provided (e.g. a sibling `util` module) before use.
        augmented_impulse = util.time_roll(impulse, samplerate, k/f_beacon)

        maxima[i] = np.max(ref_impulse + augmented_impulse)

        # Strict '>' keeps the first (smallest) k among equal maxima.
        if maxima[i] > maxima[best_i]:
            best_i = i

    return ks[best_i], (ks, maxima)


def lag_gridsearch(grid, sample_rate, reference, signal_data):
    """
    Return the best time shift found when doing a grid search.

    Parameters
    ----------
    grid - ndarray
        The integer sample lags to search (e.g. from correlation_grid).
    sample_rate - float
        Sample rate of signal_data, used to transform sample lags to time.
    reference - ndarray
        Real signal to use as reference to obtain lag.
    signal_data - ndarray
        The real signal to find the time shift for. Must be at least as
        long as `reference`.

    Returns
    -------
    lag : float
        The best time shift obtained (grid lag divided by sample_rate).
    err : ndarray of shape (2,)
        Half the distance to the previous and next grid point from lag,
        resp., in the same units as lag. At an edge of the grid the missing
        neighbour is mirrored from the available one; a single-point grid
        gives zero errors.
    """
    grid = np.asarray(grid)
    reference = np.asarray(reference)
    signal_data = np.asarray(signal_data)

    assert signal_data.shape >= reference.shape, str(signal_data.shape) + " " + str(reference.shape)

    corrs = grid_correlate(grid, reference, signal_data)

    idx = np.argmax(corrs)

    lag = grid[idx]/sample_rate

    # Steps to the neighbouring grid points. Guard the edges explicitly:
    # grid[idx-1] silently wraps to the last element when idx == 0, and
    # grid[idx+1] raises IndexError when idx is the last index.
    prev_step = grid[idx-1] - grid[idx] if idx > 0 else None
    next_step = grid[idx+1] - grid[idx] if idx < len(grid) - 1 else None

    if prev_step is None and next_step is None:
        prev_step = next_step = 0
    elif prev_step is None:
        prev_step = -next_step
    elif next_step is None:
        next_step = -prev_step

    err_min = prev_step/(2*sample_rate)
    err_plus = next_step/(2*sample_rate)

    return lag, np.array([err_min, err_plus])


def grid_correlate(grid, reference, x):
    """
    Determine correlation between x and reference using grid as
    the lags to be used for the correlation.

    The reference is zero-padded to the length of x and circularly shifted
    (np.roll) by each lag, so lags are effectively taken modulo len(x).

    Parameters
    ----------
    grid - ndarray
        The integer lags to evaluate.
    reference - ndarray
        Real signal to use as reference to obtain lag. Must not be longer
        than x.
    x - ndarray
        The real signal to find the time shift for.

    Returns
    -------
    corrs - ndarray of float64
        The correlations along grid; corrs[i] is the dot product of x with
        the conjugated reference rolled by grid[i].
    """
    grid = np.asarray(grid)
    x = np.asarray(x)
    reference = np.asarray(reference)

    # The message previously referenced an undefined `signal_data`, turning
    # a failed length check into a NameError.
    assert x.shape >= reference.shape, str(x.shape) + " " + str(reference.shape)

    # Pad so the reference can be circularly shifted against x.
    reference = np.pad(reference, (0, len(x)-len(reference)), 'constant', constant_values=0)

    ref_conj = np.conjugate(reference)

    corrs = np.array([np.dot(np.roll(ref_conj, lag), x) for lag in grid], dtype=np.float64)

    return corrs

def correlation_grid(grid_size=None, in1_len=None, in2_len = None, end = None, start=None, mode='full'):
    """
    Build an integer lag grid for grid_correlate / lag_gridsearch.

    Abuses signal.correlation_lags to determine the admissible endpoints of
    the grid when input lengths are given.

    Parameters
    ----------
    grid_size : int, optional
        Number of grid points. Defaults to ``end - start`` (unit spacing).
    in1_len, in2_len : int, optional
        Lengths of the correlated inputs. If only one is given, the other is
        taken to be equal. When given, they bound start/end via
        signal.correlation_lags(in1_len, in2_len, mode).
    end : int, optional
        Upper bound of the grid (excluded). Defaults to the largest lag.
    start : int, optional
        Lower bound of the grid (included). Defaults to the smallest lag.
    mode : str, optional
        Passed to signal.correlation_lags.

    Returns
    -------
    grid : ndarray of int
        ``grid_size`` points from start (inclusive) towards end (exclusive).
        NOTE(review): because ``endpoint=False``, the largest correlation lag
        is not part of the default grid — confirm this is intended.

    Raises
    ------
    ValueError
        If end/start lie outside the lag range, or if the grid would be
        unbounded (no input lengths and no explicit start and end).
    """

    if in1_len is not None or in2_len is not None:
        if in2_len is None:
            in2_len = in1_len
        elif in1_len is None:
            in1_len = in2_len

        lags = signal.correlation_lags(in1_len, in2_len, mode=mode)

        max_lag = lags.max()
        min_lag = lags.min()
    else:
        max_lag = np.inf
        min_lag = -np.inf

    if end is None:
        end = max_lag
    elif end > max_lag:
        raise ValueError("Grid end is too high")

    if start is None:
        start = min_lag
    elif start < min_lag:
        raise ValueError("Grid start is too low")

    # Without input lengths an omitted bound stays infinite, which would feed
    # inf/nan into np.linspace and either crash confusingly or yield garbage.
    if not (np.isfinite(start) and np.isfinite(end)):
        raise ValueError(
            "Grid is unbounded: give in1_len/in2_len or both start and end"
        )

    if grid_size is None:
        grid_size = end - start

    return np.linspace(start, end, grid_size, dtype=int, endpoint=False)