"""
Simple helpers for FFT spectra and direct (correlation-based) Fourier transforms.
"""

import numpy as np
import scipy.fftpack as ft

def get_freq_spec(val, dt):
    """Return the lower half of the FFT of `val` together with its frequencies.

    Only the first ``len(val) // 2`` bins are kept; for ``np.fft.fftfreq``
    these are the non-negative frequencies.

    Adapted from earsim/tools.py.

    :param val: sequence of samples
    :param dt: sample spacing (inverse of the sample rate)
    :return: tuple ``(spectrum, frequencies)``, both of length ``len(val) // 2``
    """
    n_samples = len(val)
    n_half = n_samples // 2

    spectrum = np.fft.fft(val)
    frequencies = np.fft.fftfreq(n_samples, dt)

    return spectrum[:n_half], frequencies[:n_half]


def ft_spectrum(signal, sample_rate=1, ftfunc=None, freqfunc=None, mask_bias=False, normalise_amplitude=False):
    """Return a FT of $signal$, with corresponding frequencies.

    Only the first ``len(signal) // 2`` bins are returned; with the default
    transform these are the non-negative frequencies.

    :param signal: sequence of samples
    :param sample_rate: samples per unit time; the frequency spacing uses
        ``1/sample_rate`` as the sample interval
    :param ftfunc: transform to apply (default ``np.fft.fft``)
    :param freqfunc: frequency generator called as ``freqfunc(n, d)``
        (default ``np.fft.fftfreq``)
    :param mask_bias: if True, drop the first (zero-frequency / DC) bin
    :param normalise_amplitude: if True, scale the spectrum by ``2/len(signal)``
    :return: tuple ``(spectrum, freqs)`` of equal length
    """
    n_samples = len(signal)
    n_half = n_samples // 2

    # NumPy's complex FFT is the default. scipy.fftpack.rfft returns a packed
    # real layout that does not line up with fftfreq, so it is not used here.
    if ftfunc is None:
        ftfunc = np.fft.fft
    if freqfunc is None:
        freqfunc = np.fft.fftfreq

    # Truncate both outputs to the same length so spectrum and freqs stay
    # aligned whatever transform the caller supplied.
    spectrum = ftfunc(signal)[:n_half]
    freqs = freqfunc(n_samples, 1/sample_rate)[:n_half]

    if normalise_amplitude:
        spectrum = spectrum * (2/n_samples)

    if mask_bias:
        return spectrum[1:], freqs[1:]
    return spectrum, freqs

def ft_corr_vectors(freqs, time):
    """
    Build the cosine and sine correlation matrices for `freqs` sampled at `time`.

    Row i, column j holds cos/sin(2*pi*freqs[i]*time[j]), i.e. the matrices
    are formed from the outer product of freqs and time.

    Returns a tuple (cosine_matrix, sine_matrix).
    """
    phase = 2*np.pi*np.outer(freqs, time)
    return np.cos(phase), np.sin(phase)

def direct_fourier_transform(freqs, time, samplesets_iterable):
    """
    Correlate each sampleset in samplesets_iterable with cosines and sines at freqs.

    All samplesets are expected to share the same time vector.

    If samplesets_iterable is iterable but has no length (e.g. a generator),
    a generator is returned which yields a (cosine_part, sine_part) tuple per
    sampleset. Otherwise samplesets_iterable is passed straight to np.dot and
    a single (cosine_part, sine_part) tuple of arrays is returned.
    """
    cos_terms, sin_terms = ft_corr_vectors(freqs, time)

    def project(samples):
        # Dot the correlation matrices against one set of samples
        return np.dot(cos_terms, samples), np.dot(sin_terms, samples)

    is_lazy = not hasattr(samplesets_iterable, '__len__') and hasattr(samplesets_iterable, '__iter__')
    if is_lazy:
        # Keep the input lazy: compute each transform only when requested
        return (project(samples) for samples in samplesets_iterable)

    return project(samplesets_iterable)



def discrete_fourier_properties(samples, samplerate):
    """
    Return (f_delta, f_nyquist) for `samples` taken at `samplerate`.

    f_delta is the frequency resolution samplerate / len(samples);
    f_nyquist is half the samplerate.
    """
    n_samples = len(samples)
    f_delta = samplerate / n_samples
    f_nyquist = samplerate / 2
    return f_delta, f_nyquist