#!/usr/bin/env python3
# vim: fdm=indent ts=4

"""
Show the signal-to-noise ratio (SNR) of the original simulated signal,
the beacon signal and the combined signal for each antenna.
"""

import numpy as np
import h5py
import matplotlib.pyplot as plt
import numpy as np

from earsim import REvent, block_filter
import aa_generate_beacon as beacon
import lib

def parse_bool(value):
    """Convert a command-line string to a bool for use as argparse ``type=``.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``,
    so every non-empty string would enable the option.

    Accepts (case-insensitive, surrounding whitespace ignored) ``1/0``,
    ``true/false``, ``t/f``, ``yes/no``, ``y/n`` and ``on/off``. Real
    ``bool`` values (e.g. an unparsed default) pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other value, which argparse
        reports as a regular usage error.
    """
    import argparse

    if isinstance(value, bool):
        return value

    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False

    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    from os import path
    import sys
    import matplotlib
    import os
    if os.name == 'posix' and "DISPLAY" not in os.environ:
        matplotlib.use('Agg')

    from scriptlib import MyArgumentParser
    parser = MyArgumentParser()

    # Bandpass
    # parse_bool instead of bool: bool("False") would be True.
    parser.add_argument('-p', '--use-passband', type=parse_bool, default=True, help='(Default: %(default)s)')
    # %g instead of %d: the float defaults (0.03, 0.08) would otherwise be shown as "0".
    parser.add_argument('-l', '--passband-low',  type=float, default=30e-3, help='Lower frequency [GHz] of the passband filter. (set -1 for np.inf) (Default: %(default)g)')
    parser.add_argument('-u', '--passband-high', type=float, default=80e-3, help='Upper frequency [GHz] of the passband filter. (set -1 for np.inf) (Default: %(default)g)')

    args = parser.parse_args()

    figsize = (12,8)

    fig_dir = args.fig_dir
    show_plots = args.show_plots

    ####
    fname_dir = args.data_dir
    antennas_fname = path.join(fname_dir, beacon.antennas_fname)
    tx_fname = path.join(fname_dir, beacon.tx_fname)
    beacon_snr_fname = path.join(fname_dir, beacon.beacon_snr_fname)
    airshower_snr_fname = path.join(fname_dir, beacon.airshower_snr_fname)

    # create fig_dir
    if fig_dir:
        os.makedirs(fig_dir, exist_ok=True)

    # Read in antennas from file
    f_beacon, tx, antennas = beacon.read_beacon_hdf5(antennas_fname, traces_key='filtered_traces')
    _, __, txdata = beacon.read_tx_file(tx_fname)

    # Read zeropadded traces
    _, __, signal_antennas = beacon.read_beacon_hdf5(antennas_fname, traces_key='original_E_AxB', read_AxB=False )

    # !!HACK!! Repack traces in signal_antennas to antennas.
    # Both reads must yield the same antennas in the same order; otherwise
    # abort with a non-zero exit status (sys.exit(str) prints to stderr).
    if len(signal_antennas) != len(antennas):
        sys.exit(f"Error! {len(signal_antennas)} signal antennas but {len(antennas)} antennas in {antennas_fname}")

    for i, (ant, signal_ant) in enumerate(zip(antennas, signal_antennas)):
        if ant.name != signal_ant.name:
            sys.exit(f"Error! antenna name mismatch at index {i}: {ant.name!r} != {signal_ant.name!r}")

        ant.orig_E_AxB = signal_ant.Ex

    # general properties
    dt = antennas[0].t[1] - antennas[0].t[0] # ns
    beacon_pb = lib.passband(f_beacon, None) # GHz
    beacon_amp = np.max(txdata['amplitudes'])# mu V/m

    # General Bandpass
    # NOTE(review): a value of -1 maps *both* edges to np.inf, as the help text
    # states; for the lower edge 0 may have been intended — confirm.
    low_bp  = args.passband_low  if args.passband_low  >= 0 else np.inf # GHz
    high_bp = args.passband_high if args.passband_high >= 0 else np.inf # GHz
    pb = lib.passband(low_bp, high_bp) # GHz

    noise_pb = pb

    if args.use_passband: # Apply filter to raw beacon/noise to compare with Filtered Traces
        myfilter = lambda x: block_filter(x, dt, pb[0], pb[1])
    else: # Compare raw beacon/noise with Filtered Traces
        myfilter = lambda x: x

    ##
    ## Debug plot of Beacon vs Noise SNR
    ##
    if True:
        ant = antennas[0]

        fig, ax = plt.subplots(figsize=figsize)
        _debug_snrs = lib.signal_to_noise(myfilter(beacon_amp*ant.beacon), myfilter(ant.noise), samplerate=1/dt, signal_band=beacon_pb, noise_band=noise_pb, debug_ax=ax, mode='sine')

        ax.legend(title="$\\langle SNR \\rangle$ = {: .1e}".format(np.mean(_debug_snrs)))

        ax.set_title("Spectra and passband")
        ax.set_xlabel("Frequency [GHz]")
        ax.set_ylabel("Amplitude")
        # beacon_pb has no upper edge (None), hence the `or 0`.
        low_x, high_x = min(beacon_pb[0], noise_pb[0]), max(beacon_pb[1] or 0, noise_pb[1])
        # matplotlib rejects infinite axis limits (a passband edge of -1 maps to np.inf).
        if np.isfinite(low_x) and np.isfinite(high_x):
            ax.set_xlim(low_x, high_x)

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".beacon_vs_noise_snr.debug_plot.pdf"))

    ##
    ## Beacon vs Noise SNR
    ##
    if True:
        N_samples = len(antennas[0].beacon)
        beacon_snrs = [ lib.signal_to_noise(myfilter(beacon_amp*ant.beacon), myfilter(ant.noise), samplerate=1/dt, signal_band=beacon_pb, noise_band=noise_pb, mode='sine') for ant in antennas ]

        # write mean and std to file
        beacon.write_snr_file(beacon_snr_fname, beacon_snrs)

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title(f"Maximum Beacon/Noise SNR (N_samples:{N_samples:.1e})")
        ax.set_xlabel("Antenna no.")
        ax.set_ylabel("SNR")
        ax.plot([ int(ant.name) for ant in antennas], beacon_snrs, 'o', ls='none')

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".beacon_vs_noise_snr.pdf"))

    ##
    ## Beacon vs Total SNR
    ##
    if True:
        beacon_snrs = [ lib.signal_to_noise(myfilter(beacon_amp*ant.beacon), ant.E_AxB, samplerate=1/dt, signal_band=beacon_pb, noise_band=pb, mode='sine') for ant in antennas ]

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title("Maximum Beacon/Total SNR")
        ax.set_xlabel("Antenna no.")
        ax.set_ylabel("SNR")
        ax.plot([ int(ant.name) for ant in antennas], beacon_snrs, 'o', ls='none')

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".beacon_vs_total_snr.pdf"))

    ##
    ## Debug plot of Signal vs Noise SNR
    ##
    if True:
        ant = antennas[0]

        fig, ax = plt.subplots(figsize=figsize)
        _debug_snrs = lib.signal_to_noise(myfilter(ant.orig_E_AxB), myfilter(ant.noise), samplerate=1/dt, debug_ax=ax, mode='pulse')

        ax.legend(title="$\\langle SNR \\rangle$ = {: .1e}".format(np.mean(_debug_snrs)))

        ax.set_title("Signal (max amp) and Noise (rms)")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".airshower_vs_noise_snr.debug_plot.pdf"))

    ##
    ## Signal vs Noise SNR
    ##
    if True:
        airshower_snrs = [ lib.signal_to_noise(myfilter(ant.orig_E_AxB), myfilter(ant.noise), samplerate=1/dt, mode='pulse') for ant in antennas ]

        # write mean and std to file
        beacon.write_snr_file(airshower_snr_fname, airshower_snrs)

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title("Maximum Airshower/Noise SNR")
        ax.set_xlabel("Antenna no.")
        ax.set_ylabel("SNR")
        ax.plot([ int(ant.name) for ant in antennas], airshower_snrs, 'o', ls='none')

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".airshower_vs_noise_snr.pdf"))

    ##
    ## Debug plot of Signal vs Beacon SNR
    ##
    if True:
        ant = antennas[0]

        fig, ax = plt.subplots(figsize=figsize)
        if False: #indirect SNR max_amp(signal) vs max_amp(beacon)
            _debug_snrs_E_AxB = lib.signal_to_noise(myfilter(ant.orig_E_AxB), myfilter(ant.noise), samplerate=1/dt, debug_ax=ax, mode='pulse')
            _debug_snrs_sine = lib.signal_to_noise(myfilter(beacon_amp*ant.beacon), myfilter(ant.noise), samplerate=1/dt, debug_ax=ax, mode='pulse')

            _debug_snrs = _debug_snrs_E_AxB / _debug_snrs_sine
        else: # direct max_amp(signal) vs rms(beacon)
            _debug_snrs = lib.signal_to_noise(myfilter(ant.orig_E_AxB), myfilter(beacon_amp*ant.beacon), samplerate=1/dt, debug_ax=ax, mode='pulse')

        ax.legend(title="$\\langle SNR \\rangle$ = {: .1e}".format(np.mean(_debug_snrs)))

        ax.set_title("Signal (max amp), Beacon (max amp) and Noise (rms)")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".airshower_vs_beacon_snr.debug_plot.pdf"))


    ##
    ## Signal vs Beacon SNR
    ##
    if True:
        shower_beacon_snrs = [ lib.signal_to_noise(myfilter(ant.orig_E_AxB), myfilter(beacon_amp*ant.beacon), samplerate=1/dt, mode='pulse') for ant in antennas ]

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title("Maximum Airshower/Beacon RMS SNR")
        ax.set_xlabel("Antenna no.")
        ax.set_ylabel("SNR")
        # Plot the airshower/beacon SNRs computed just above (previously this
        # plotted the stale beacon/total `beacon_snrs` by mistake).
        ax.plot([ int(ant.name) for ant in antennas], shower_beacon_snrs, 'o', ls='none')

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".airshower_vs_beacon_snr.pdf"))



    ##
    ## Total signal vs Noise SNR
    ##
    if True:
        shower_snrs = [ lib.signal_to_noise(ant.E_AxB, myfilter(ant.noise), samplerate=1/dt, mode='pulse') for ant in antennas ]

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title("Total (Signal+Beacon+Noise)/Noise SNR")
        ax.set_xlabel("Antenna no.")
        ax.set_ylabel("SNR")
        ax.plot([ int(ant.name) for ant in antennas], shower_snrs, 'o', ls='none')

        if fig_dir:
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".total_snr.pdf"))

    if show_plots:
        plt.show()