import numpy as np
from collections import namedtuple

from .fft import ft_spectrum

class passband(namedtuple("passband", ['low', 'high'], defaults=[0, np.inf])):
    """
    Band for a bandpass filter.

    It encapsulates a ``(low, high)`` tuple. ``low`` defaults to 0 and
    ``high`` to ``np.inf``, so ``passband()`` lets every frequency through.
    The methods are thin wrappers around the module-level band helpers.
    """

    def size(self):
        """Return the width of the band (see ``bandsize``)."""
        return bandsize(self)

    def freq_mask(self, frequencies):
        """Return the boolean mask from ``bandpass_mask`` for ``frequencies``."""
        return bandpass_mask(frequencies, self)

    def signal_power(self, samples, samplerate, normalise_bandsize=True, **ft_kwargs):
        """
        Return the power of ``samples`` inside this band.

        Delegates to ``bandpower``; ``normalise_bandsize`` and ``ft_kwargs``
        are forwarded unchanged.
        """
        return bandpower(samples, samplerate, self, normalise_bandsize, **ft_kwargs)

    def filter_samples(self, samples, samplerate, **ft_kwargs):
        """
        Bandpass the samples with this passband.

        This is a hard filter: every spectral bin outside the band is set
        to zero before transforming back to the time domain.
        ``ft_kwargs`` are forwarded to ``ft_spectrum``.
        """
        fft, freqs = ft_spectrum(samples, samplerate, **ft_kwargs)

        fft[~self.freq_mask(freqs)] = 0

        # NOTE(review): irfft assumes ft_spectrum returns a one-sided (rfft)
        # spectrum -- confirm; for odd-length input consider passing
        # n=len(samples) so the output length round-trips.
        return np.fft.irfft(fft)


def bandpass_samples(samples, samplerate, band=passband(), **ft_kwargs):
    """
    Bandpass ``samples`` with ``band``.

    This is a hard filter: every spectral bin whose frequency falls outside
    ``band`` (as decided by ``bandpass_mask``) is set to zero, then the
    spectrum is transformed back with ``np.fft.irfft``.

    Parameters
    ----------
    samples : array-like
        Signal passed to ``ft_spectrum``.
    samplerate : float
        Sample rate passed to ``ft_spectrum``.
    band : passband or (low, high)
        Frequency band to keep.
    **ft_kwargs
        Forwarded to ``ft_spectrum``.
    """
    fft, freqs = ft_spectrum(samples, samplerate, **ft_kwargs)

    # Zero every bin outside the requested band.
    fft[~bandpass_mask(freqs, band)] = 0

    # NOTE(review): irfft assumes ft_spectrum returns a one-sided (rfft)
    # spectrum -- confirm.
    return np.fft.irfft(fft)

def bandpass_mask(freqs, band=passband()):
    """
    Return a boolean mask selecting the frequencies that lie inside ``band``.

    The comparison is made on the absolute frequency, so a negative
    frequency is kept whenever its positive counterpart is. Both band edges
    are inclusive.

    Parameters
    ----------
    freqs : array-like
        Frequencies to test. Plain sequences are accepted as well as arrays.
    band : passband or (low, high)
        Indexed positionally, so a plain 2-tuple works too.

    Returns
    -------
    numpy.ndarray of bool with the same shape as ``freqs``.
    """
    # Take the magnitude once; asarray lets lists/tuples work as well as arrays.
    magnitudes = np.abs(np.asarray(freqs))
    low, high = band[0], band[1]

    return (magnitudes >= low) & (magnitudes <= high)

def bandsize(band=passband()):
    """Return the width of ``band``: its second element minus its first."""
    lower_edge = band[0]
    upper_edge = band[1]
    return upper_edge - lower_edge

def bandpower(samples, samplerate=1, band=passband(), normalise_bandsize=True, **ft_kwargs):
    """
    Return the spectral power of ``samples`` inside ``band``.

    The power is the sum of ``|X(f)|**2`` over the spectral bins selected by
    ``bandpass_mask``. With ``normalise_bandsize`` it is divided by the
    number of selected bins, giving the mean power per bin.

    Parameters
    ----------
    samples : array-like
        Signal(s) passed to ``ft_spectrum``.
    samplerate : float
        Sample rate passed to ``ft_spectrum``.
    band : passband or (low, high)
        Frequency band to integrate over.
    normalise_bandsize : bool
        Divide by the number of in-band bins.
    **ft_kwargs
        Forwarded to ``ft_spectrum``.

    Returns
    -------
    The power reduced over the last axis: a scalar for a single 1-D
    spectrum, or one value per leading index for a batched spectrum.
    If no bin falls inside the band and ``normalise_bandsize`` is set, the
    division is 0/0, which yields nan with a numpy RuntimeWarning
    (assuming ``ft_spectrum`` returns numpy arrays).
    """
    fft, freqs = ft_spectrum(samples, samplerate, **ft_kwargs)

    bandmask = bandpass_mask(freqs, band=band)

    # Reduce over the last (frequency) axis with `where=` rather than
    # boolean indexing. `fft[bandmask]` flattens every axis (and fails when
    # a 1-D mask meets a batched spectrum), which disagreed with the
    # per-row bin count below. For 1-D input the result is the same sum.
    power = np.sum(np.abs(fft) ** 2, axis=-1, where=bandmask)

    if normalise_bandsize:
        bins = np.count_nonzero(bandmask, axis=-1)
    else:
        bins = 1

    return power / bins

def signal_to_noise(samples, noise=None, samplerate=1, signal_band=passband(),
                    noise_band=None, normalise_bandsize=True, **ft_kwargs):
    """
    Return the amplitude signal-to-noise ratio ``sqrt(P_signal / P_noise)``.

    Both powers are computed with ``bandpower``.

    Parameters
    ----------
    samples : array-like
        Signal whose power in ``signal_band`` forms the numerator.
    noise : array-like, optional
        Noise recording whose power in ``noise_band`` forms the denominator.
        When None, ``samples`` is used, so the noise is measured in
        ``noise_band`` of the signal itself.
    samplerate : float
        Sample rate passed to ``bandpower``.
    signal_band : passband or (low, high)
        Band in which the signal power is measured.
    noise_band : passband or (low, high), optional
        Band in which the noise power is measured; defaults to
        ``signal_band``. Note that if both ``noise`` and ``noise_band`` are
        left unset, numerator and denominator are the same quantity.
    normalise_bandsize : bool
        Forwarded to ``bandpower`` for both powers.
    **ft_kwargs
        Forwarded to ``bandpower`` (and from there to ``ft_spectrum``).

    Returns
    -------
    The square root of the power ratio. A zero noise power yields inf/nan
    under numpy division rules.
    """
    if noise_band is None:
        noise_band = signal_band

    if noise is None:
        noise = samples

    noise_power = bandpower(noise, samplerate, noise_band, normalise_bandsize, **ft_kwargs)

    signal_power = bandpower(samples, samplerate, signal_band, normalise_bandsize, **ft_kwargs)

    # Powers scale with amplitude squared, so the square root gives an amplitude ratio.
    return (signal_power / noise_power) ** 0.5