"""
Define the Signal base class.
"""

import numpy as np

class Signal:
    """
    An arbitrary signal that can be translated to another position and time.
    Note that position can be of any length.

    Super object, cannot be used directly: subclasses must implement
    ``_translate`` (calling the base implementation raises
    ``NotImplementedError``).
    """

    # Default propagation velocity: the speed of light in vacuum, in m/s.
    SPEED_OF_LIGHT = 299792458

    def __init__(self, t_0=0, x_0=0, velocity=None, t_f=None, x_f=None):
        """
        Parameters
        ----------
        t_0 : float, optional
            Time that this signal is sent out (s).
        x_0 : float or array_like, optional
            Location that this signal is sent out from (m). For distance
            calculations the last axis holds the spatial coordinates.
        velocity : float, optional
            Propagation velocity (m/s). Defaults to the speed of light.
        t_f : float, optional
            Default time that this signal is received.
        x_f : float or array_like, optional
            Default location that this signal is received.

        Raises
        ------
        ValueError
            If ``t_0`` or ``x_0`` is None.
        """
        if t_0 is None:
            raise ValueError("t_0 cannot be None")
        if x_0 is None:
            raise ValueError("x_0 cannot be None")

        self.x_0 = np.asarray(x_0)  # m
        self.t_0 = np.asarray(t_0)  # s

        self.velocity = self.SPEED_OF_LIGHT if velocity is None else velocity  # m / s

        # Default final time/position; None means "not set".
        self.x_f = self._default_array(x_f, None)
        self.t_f = self._default_array(t_f, None)

    @staticmethod
    def _default_array(value, default):
        """Return ``value`` (or ``default`` when ``value`` is None) as an ndarray, or None if both are None."""
        if value is None:
            value = default
        return None if value is None else np.asarray(value)

    def __call__(self, t_f=None, x_f=None, **kwargs):
        """
        Allow this class to be used as a function.

        Forwards all arguments to ``_translate`` and returns the first element
        of its result (the subclass defines what ``_translate`` returns).
        """
        return self._translate(t_f, x_f, **kwargs)[0]

    def _translate(self, t_f=None, x_f=None, t_0=None, x_0=None, velocity=None):
        """
        Translate the signal from (t_0, x_0) to (t_f, x_f) with optional velocity.

        Returns the signal at (t_f, x_f). Must be implemented by subclasses.
        """
        raise NotImplementedError

    def spatial_time_offset(self, x_f=None, x_0=None, velocity=None):
        """
        Calculate the time offset caused by a spatial distance.

        The offset is the Euclidean distance between ``x_0`` and ``x_f``
        (summed over the last axis) divided by the propagation velocity.

        Parameters
        ----------
        x_f : array_like, optional
            Final position; defaults to ``self.x_f``.
        x_0 : array_like, optional
            Starting position; defaults to ``self.x_0``.
        velocity : float, optional
            Propagation velocity; defaults to ``self.velocity``.

        Returns
        -------
        The travel time ``|x_f - x_0| / velocity``.

        Raises
        ------
        TypeError
            If no final position is available (``x_f`` and ``self.x_f``
            are both None).
        """
        if velocity is None:
            velocity = self.velocity

        x_0 = self._default_array(x_0, self.x_0)
        x_f = self._default_array(x_f, self.x_f)
        if x_f is None:
            # Previously this surfaced as an opaque "NoneType - int" error.
            raise TypeError("x_f must be given when no default x_f is set")

        return np.sqrt(np.sum((x_f - x_0) ** 2, axis=-1)) / velocity

    def temporal_time_offset(self, t_f=None, t_0=None):
        """
        Calculate the time offset caused by a temporal distance.

        Parameters
        ----------
        t_f : float or array_like, optional
            Final time; defaults to ``self.t_f``.
        t_0 : float or array_like, optional
            Starting time; defaults to ``self.t_0``.

        Returns
        -------
        ``t_f - t_0``.

        Raises
        ------
        TypeError
            If no final time is available (``t_f`` and ``self.t_f`` are
            both None).
        """
        t_0 = self._default_array(t_0, self.t_0)
        t_f = self._default_array(t_f, self.t_f)
        if t_f is None:
            raise TypeError("t_f must be given when no default t_f is set")

        return t_f - t_0

    def total_time_offset(self, t_f=None, x_f=None, t_0=None, x_0=None, velocity=None):
        """
        Calculate how much time shifting is needed to go from (t_0, x_0) to (t_f, x_f).

        The result is ``(t_f - t_0) - |x_f - x_0| / velocity``, so:
            - a later final time (t_f > t_0) contributes a positive shift;
            - a different final position (x_f != x_0) contributes a negative
              shift equal to the travel time between the positions.
        Any argument left as None falls back to the corresponding attribute.
        If the final time or final position is still None after that, its
        term contributes 0 instead of raising.

        Returns:
         the time shift
        """
        t_0 = self._default_array(t_0, self.t_0)
        x_0 = self._default_array(x_0, self.x_0)
        t_f = self._default_array(t_f, self.t_f)
        x_f = self._default_array(x_f, self.x_f)

        spatial = (
            0 if x_f is None
            else self.spatial_time_offset(x_f, x_0=x_0, velocity=velocity)
        )
        temporal = (
            0 if t_f is None
            else self.temporal_time_offset(t_f, t_0=t_0)
        )

        return temporal - spatial