"""
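Various useful signal-processing utilities: sample grids, phase/time conversion, edge detection, FFT bandpassing and test-signal generation.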
"""

import numpy as np
import scipy.fft as ft

def sampled_time(sample_rate=1, start=0, end=1, offset=0):
    """
    Return sample times from `start` (inclusive) towards `end`
    (exclusive), spaced ``1/sample_rate`` apart and shifted by `offset`.

    The grid comes from ``np.arange`` with a floating-point step, so for
    non-integer steps whether a sample lands just below `end` is subject
    to rounding.
    """
    step = 1/sample_rate
    grid = np.arange(start, end, step)
    return offset + grid

def rot_vector(phi1=0.12345):
    """
    Return the unit vector ``(1, 0)`` rotated counter-clockwise by
    `phi1` radians.

    Parameters
    ----------
    phi1 : float or array_like
        Rotation angle in radians. Array input yields an array of shape
        ``(2,) + np.shape(phi1)``.

    Returns
    -------
    ndarray
        ``[cos(phi1), sin(phi1)]``.
    """
    # Compute the y component with sin directly; the former
    # cos(phi1 - pi/2) leaves rounding residue (e.g. 6.1e-17 instead
    # of an exact 0 for phi1 = 0).
    return np.array([np.cos(phi1), np.sin(phi1)])

def detect_edges(threshold, data, rising=True, falling=False):
    """
    Detect rising/falling edges in data, returning the indices
    of the detected edges.

    A rising edge is a consecutive pair with
    ``data[i-1] < threshold < data[i]``; a falling edge is a pair with
    ``data[i-1] > threshold > data[i]``. Both comparisons are strict, so
    a transition that lands exactly on `threshold` is not reported.

    https://stackoverflow.com/a/50365462

    Parameters
    ----------
    threshold : float
    data : array_like
        One-dimensional sequence of samples (lists are accepted).
    rising : bool
        Report rising edges.
    falling : bool
        Report falling edges.

    Returns
    -------
    ndarray of int
        Ascending indices ``i`` of the first sample after each detected
        edge. Empty when `data` has fewer than two samples.
    """
    # Convert so that plain lists support the elementwise comparisons.
    data = np.asarray(data)

    # One slot per consecutive pair; clamp at 0 so empty input yields
    # an empty mask instead of a negative-size allocation error.
    mask = np.zeros(max(len(data) - 1, 0), dtype=bool)

    if rising:
        mask |= (data[:-1] < threshold) & (data[1:] > threshold)

    if falling:
        mask |= (data[:-1] > threshold) & (data[1:] < threshold)

    # +1 so each index points at the sample after the crossing.
    return np.flatnonzero(mask) + 1

def sin_delay(f, t, t_delay=0, phase=0):
    """
    Evaluate ``sin(2*pi*f*(t - t_delay) + phase)``: a sine of frequency
    `f` sampled at times `t`, delayed by `t_delay` and offset by `phase`
    radians.
    """
    delayed_t = t - t_delay
    argument = 2*np.pi*f*delayed_t + phase
    return np.sin(argument)

def time2phase(time, frequency=1):
    """Convert `time` to the phase in radians accumulated at `frequency`."""
    angular_frequency = 2*np.pi*frequency
    return angular_frequency*time

def phase2time(phase, frequency=1):
    """Convert a phase in radians to the time it takes to accumulate at `frequency`."""
    angular_frequency = 2*np.pi*frequency
    return phase/angular_frequency

def phase_modulo(phase, low=np.pi):
    r"""
    Modulo phase such that it falls within the
    interval $[-low, 2\pi - low)$.

    With the default ``low = pi`` the result lies in $[-\pi, \pi)$; with
    ``low = 0`` it lies in $[0, 2\pi)$. Both bounds hold up to
    floating-point rounding of the ``%`` operation.

    Parameters
    ----------
    phase : float or array_like
        Phase in radians.
    low : float
        Distance of the interval's lower bound below zero.

    Returns
    -------
    float or ndarray
        The wrapped phase.
    """
    # Shift so the target interval starts at 0, wrap, then shift back.
    return (phase + low) % (2*np.pi) - low

def time_roll(a, samplerate, time_shift, sample_shift=0, int_func=lambda x: np.rint(x).astype(int), **roll_kwargs):
    """
    Roll `a` like ``np.roll``, but express the shift as a time.

    The shift in samples is ``time_shift*samplerate + sample_shift``,
    turned into an integer by `int_func` (by default rounded to the
    nearest integer via ``np.rint``). Any extra keyword arguments, such as
    ``axis``, are forwarded to ``np.roll``.
    """
    fractional_samples = time_shift*samplerate + sample_shift
    return np.roll(a, int_func(fractional_samples), **roll_kwargs)

### signal filtering and generation
def fft_bandpass(signal, band, samplerate):
    """
    Simple bandpassing function employing a FFT.

    Every real-FFT bin whose frequency lies outside the closed interval
    ``[band[0], band[1]]`` is zeroed before transforming back. The
    transform runs along the last axis, so a 2-D array is filtered row
    by row; 1-D input behaves exactly as a single row.

    Parameters
    ----------
    signal : arraylike
        Real-valued samples; filtering is applied along the last axis.
    band : tuple(low, high)
        Frequencies for bandpassing, in the same units as `samplerate`;
        both ends are kept.
    samplerate : float

    Returns
    -------
    filtered : ndarray
        Bandpassed signal with the same shape as `signal`.
    (fft, freqs) : tuple of ndarray
        The masked real FFT (along the last axis) and the frequency of
        each of its bins.
    """
    signal = np.asarray(signal)
    # Length of the transformed axis; equals signal.size for 1-D input.
    n = signal.shape[-1]

    fft = ft.rfft(signal, axis=-1)
    freqs = ft.rfftfreq(n, 1/samplerate)
    # Ellipsis applies the 1-D frequency mask to the last axis for any rank.
    fft[..., (freqs < band[0]) | (freqs > band[1])] = 0

    # Pass n explicitly so odd-length signals round-trip to their length.
    return ft.irfft(fft, n, axis=-1), (fft, freqs)

def deltapeak(timelength=1e3, samplerate=1, offset=None, peaklength=1):
    """
    Generate a series of zeroes containing one block of ones (a delta peak).

    The series has ``int(timelength * samplerate)`` samples and is treated
    as periodic: a peak running past the end wraps around to the start.

    Parameters
    ----------
    timelength : float
        Duration of the series; multiplied by `samplerate` to obtain the
        number of samples.
    samplerate : float
    offset : int, float, None or tuple(low, high)
        Sample index where the peak starts. It is converted to an integer
        by truncation and taken modulo the number of samples. A tuple or
        list draws the start uniformly from ``[low, high)`` using the
        global ``np.random`` state; either end may be ``None``, meaning 0
        or the number of samples. ``None`` behaves like ``(None, None)``,
        i.e. a random location anywhere in the series.
    peaklength : int
        Number of consecutive samples set to 1.

    Returns
    -------
    signal : ndarray
        Float array of zeroes with ones at `position`.
    position : ndarray of int
        Indices of the samples that were set to 1.
    """
    n_samples = int(timelength * samplerate)

    if offset is None or isinstance(offset, (tuple, list)):
        bounds = (None, None) if offset is None else offset
        first, last = bounds[0], bounds[-1]
        lower = 0 if first is None else first
        upper = n_samples if last is None else last

        # One random start in [lower, upper), folded into the period.
        draw = np.random.random(1)*(upper - lower)+lower
        offset = draw.astype(int) % n_samples

    peak_steps = np.arange(0, peaklength)
    position = (offset + peak_steps).astype(int) % n_samples

    series = np.zeros(n_samples)
    series[position] = 1

    return series, position