#!/usr/bin/env python3

import numpy as np
import scipy.interpolate as interp

if __name__ == "__main__" and __package__ is None:
    import sys
    sys.path.append("../../")
    __package__ = "lib.signals"

from .signal import *

class DigitisedSignal(Signal):
    """
    Model an arbitrary digitised signal that can be translated to another position and time.

    A translation is converted into a fractional sample index into the raw
    samples, and the amplitude is interpolated between neighbouring samples,
    either by a built-in linear interpolation or by scipy.interpolate.interp1d.
    """

    def __init__(self, signal, sample_rate, t_0 = 0, x_0 = 0, periodic=True, interp1d_kw = None, velocity=None, t_f = None, x_f = None):
        """
        Initialise by saving the raw signal

        Parameters
        ----------
        signal : arraylike
            The raw signal to wrap (a sequence of samples).
        sample_rate : float
            Sample rate of the raw signal in Hz.
        t_0 : float, optional
            Time that this signal is sent out.
        x_0 : float, optional
            Location that this signal is sent out from.
        periodic : bool, optional
            If True, the signal repeats every len(signal) samples.
            If False, the translated signal is 0 when the time/distance
            falls outside the samples.
        interp1d_kw : bool or dict, optional
            Use scipy.interpolate.interp1d for interpolation.
            Set to True to use the default settings, or to a dictionary
            whose entries are passed to interp1d as **kwargs (overriding
            the defaults).
        velocity : float, optional
            Defaults to the speed of light in m/s.
        t_f : float, optional
            Default time that this signal is received.
        x_f : float, optional
            Default Location that this signal is received.
        """
        super().__init__(t_0=t_0, x_0=x_0, velocity=velocity, t_f=t_f, x_f=x_f)

        self.raw = np.asarray(signal)
        self.periodic = periodic

        self.sample_rate = sample_rate # Hz
        self.sample_length = len(self.raw)
        self.time_length = self.sample_length/sample_rate # s

        # choose interpolation method
        if not interp1d_kw:
            # use the built-in linear interpolation in _translate
            self.interp_f = None
            return

        # offload interpolation to scipy.interpolate
        interp1d_kw_defaults = {
            "copy": False,
            "kind": 'linear',
            "assume_sorted": True,
            "bounds_error": True
        }

        if self.periodic:
            interp1d_kw_defaults['bounds_error'] = False
            interp1d_kw_defaults['fill_value'] = (self.raw[-1], self.raw[0])

        # merge kwargs; True means "defaults only" (True itself cannot be **-unpacked)
        if interp1d_kw is True:
            interp1d_kw = interp1d_kw_defaults
        else:
            interp1d_kw = { **interp1d_kw_defaults, **interp1d_kw }

        # The grid gets one extra point at index sample_length holding a copy of
        # the first sample. _translate can ask for indices in
        # (sample_length-1, sample_length]; with the extra point these blend the
        # last and first samples (as the built-in interpolation does) instead of
        # raising through bounds_error or jumping to the fill value.
        self.interp_f = interp.interp1d(
                            np.arange(0, self.sample_length + 1),
                            np.concatenate((self.raw, self.raw[:1])),
                            **interp1d_kw
                        )

    def __len__(self):
        """Return the number of samples in the raw signal."""
        return self.sample_length

    def raw_time(self):
        """
        Return the sample times of the raw signal in seconds, starting at 0.

        Built from integer sample indices so the result always has exactly
        len(self) entries; a float-step np.arange can produce one extra
        point through rounding.
        """
        return np.arange(self.sample_length) / self.sample_rate

    def _translate(self, t_f = None, x_f = None, t_0 = None, x_0 = None, velocity = None):
        """
        Translate the signal from (t_0, x_0) to (t_f, x_f) with optional velocity.

        Returns the signal at (t_f, x_f) and the total time offset.

        The total time offset is converted into a fractional sample index into
        the raw signal. A periodic signal wraps that index into one period. A
        non-periodic signal gives amplitude 0 wherever the index falls outside
        [0, sample_length), and also where it is NaN.
        """

        total_time_offset = self.total_time_offset(t_f=t_f, x_f=x_f, t_0=t_0, x_0=x_0, velocity=velocity)
        n_offset = np.asarray(total_time_offset * self.sample_rate)

        # boolean mask of positions outside the samples (non-periodic only)
        outside = None

        if self.periodic:
            # Float rounding can make this exactly sample_length (e.g. for a tiny
            # negative offset); the index wrap below maps that back onto sample 0.
            n_offset = n_offset % self.sample_length

        else:
            # written as a negated in-range test so that NaN counts as outside
            outside = ~((0 <= n_offset) & (n_offset < self.sample_length))

            # a single position outside the samples: nothing to interpolate
            if np.ndim(n_offset) == 0 and outside:
                return 0, total_time_offset

            # park out-of-range positions on a valid index; they are zeroed below
            n_offset = np.where(outside, 0, n_offset)

        if self.interp_f is not None:
            # offload to scipy interpolation
            amplitude = self.interp_f(n_offset)

        else:
            # self written linear interpolation between neighbouring samples
            n_offset_eps, n_offset_int = np.modf(n_offset)
            # wrap so an index of exactly sample_length (see above) stays in range
            n_offset_int = n_offset_int.astype(int) % self.sample_length

            # the upper neighbour of the last sample is the first sample
            amplitude = (1-n_offset_eps) * self.raw[n_offset_int] \
                        + n_offset_eps * self.raw[(n_offset_int + 1) % self.sample_length]

        if outside is not None and np.ndim(amplitude) > 0:
            amplitude = np.where(outside, 0, amplitude)

        return amplitude, total_time_offset

if __name__ == "__main__":
    # Demo: a Gaussian pulse sampled over 1 s, evaluated over a longer time
    # window both as a periodic and as a non-periodic DigitisedSignal.
    import matplotlib.pyplot as plt
    from scipy.stats import norm

    sample_rate = 3e2 # Hz

    t_offset = 8

    time   = t_offset + np.arange(0, 1, 1/sample_rate) #s
    time2  = t_offset + np.arange(-1.5, 1, 1/sample_rate) #s

    # Gaussian centred in the sampled window, with a width of a tenth of the window
    signal = norm.pdf(time, time[len(time)//2], (time[-1] - time[0])/10)

    mysignal = DigitisedSignal(signal, sample_rate, t_0 = t_offset, periodic=True)
    mysignal2 = DigitisedSignal(signal, sample_rate, t_0 = t_offset, periodic=False)

    fig, ax = plt.subplots(1, 1, figsize=(16,4))
    ax.set_title("Raw and DigitisedSignal")
    ax.set_ylabel("Amplitude")
    ax.set_xlabel("Time")

    # the translated curves are shifted up/down by 0.5 so they do not overlap the raw one
    ax.plot(time,   signal,                     label='Raw signal')
    ax.plot(time2,  mysignal(time2) +0.5, '.-', label='DigitisedSignal(periodic)+0.5')
    ax.plot(time2,  mysignal2(time2)-0.5, '.-', label='DigitisedSignal-0.5')

    ax.legend()

    plt.show()