#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import numpy.fft as ft

import aa_generate_beacon as beacon
from view_orig_ant0 import plot_antenna_geometry
import lib
from earsim import Antenna


def select_traces(ant, polarisations, beacon_amp):
    """Collect the traces of one antenna that were requested for plotting.

    Parameters
    ----------
    ant
        Antenna-like object exposing the requested traces as attributes
        (``E<p>`` for plain polarisations, ``E_AxB``, ``beacon``, ``noise``).
    polarisations
        Iterable of polarisation keys as accepted by ``--polarisations``.
    beacon_amp
        Factor applied to the ``beacon`` attribute for the ``'b'`` and
        ``'b+n'`` keys.

    Returns
    -------
    dict
        Maps each requested key to its trace, in order of first request.
    """
    # Keys whose attribute name is not simply 'E' + key.
    attr_names = {'b': 'beacon', 'n': 'noise', 'AxB': 'E_AxB'}

    traces = {}
    for p in polarisations:
        if p == 'b+n':
            traces[p] = ant.noise + beacon_amp * ant.beacon
        elif p == 'b':
            # Build a new array instead of scaling in place: an in-place
            # multiply would modify the antenna's own beacon data.
            traces[p] = beacon_amp * ant.beacon
        else:
            traces[p] = getattr(ant, attr_names.get(p, 'E' + str(p)))

    return traces


if __name__ == "__main__":
    from os import path
    import matplotlib
    import os

    # Without a display on posix, switch to the non-interactive Agg backend.
    if os.name == 'posix' and "DISPLAY" not in os.environ:
        matplotlib.use('Agg')

    from scriptlib import MyArgumentParser
    parser = MyArgumentParser()
    parser.add_argument('ant_idx', default=[72], nargs='*', type=int, help='Antenna Indices')
    parser.add_argument('-p', '--polarisations', choices=['x', 'y', 'z', 'b', 'AxB', 'n', 'b+n'], action='append', help='Default: x,y,z')
    parser.add_argument('--geom', action='store_true', help='Make a figure containg the geometry from tx to antenna(s)')
    parser.add_argument('--ft', action='store_true', help='Add FT strenghts of antenna traces')

    args = parser.parse_args()

    figsize = (12, 8)

    plot_ft_amplitude = args.ft
    plot_geometry = args.geom

    # Developer toggles (not exposed on the command line).
    plot_ft_phase = False   # add a third panel with the phase at f_beacon
    zoom_on_beacon = False  # zoom the spectrum to f_beacon +- 0.01

    # fig_dir, show_plots and data_dir are not declared above; presumably
    # MyArgumentParser provides them -- verify against scriptlib.
    fig_dir = args.fig_dir
    show_plots = args.show_plots

    if not args.polarisations:
        args.polarisations = ['x', 'y', 'z']

    ####
    fname_dir = args.data_dir
    antennas_fname = path.join(fname_dir, beacon.antennas_fname)
    tx_fname = path.join(fname_dir, beacon.tx_fname)

    f_beacon, tx, antennas = beacon.read_beacon_hdf5(antennas_fname)
    _, __, txdata = beacon.read_tx_file(tx_fname)
    beacon_amp = np.max(txdata['amplitudes'])  # mu V/m

    idx = args.ant_idx

    if not idx:
        # No indices given: fall back to a fixed set of central antenna names.
        center_names = {55, 56, 57, 65, 66, 45, 46}
        idx = [i for i, ant in enumerate(antennas) if int(ant.name) in center_names]

    # Every antenna contributes one plotted trace (and thus one colour) per
    # distinct requested polarisation.
    n_traces_per_antenna = len(set(args.polarisations))

    # First figure: traces as stored by default; second figure: the traces
    # stored under the 'prefiltered_traces' key.
    for i_fig in range(2):
        name_dist = ''

        if i_fig == 1:  # read in the raw_traces
            _, __, antennas = beacon.read_beacon_hdf5(antennas_fname, traces_key='prefiltered_traces')
            name_dist = '.raw'

        n_panels = 1
        if plot_ft_amplitude:
            n_panels += 2 if plot_ft_phase else 1

        fig1, axs = plt.subplots(n_panels, figsize=figsize)
        # plt.subplots returns a bare Axes for a single panel; always index.
        axs = np.atleast_1d(axs)

        axs[0].set_xlabel('t [ns]')
        # Raw string: '\m' is not a valid escape sequence.
        axs[0].set_ylabel(r'[$\mu$V/m]')

        if i_fig == 1:
            axs[0].set_title("UnFiltered traces")
        else:
            axs[0].set_title("Filtered traces")

        axs[0].set_xlim(-250, 250)

        if plot_ft_amplitude:
            axs[1].set_xlabel('f [GHz]')
            axs[1].set_ylabel('Power')

            if len(axs) > 2:
                axs[2].set_ylabel("Phase")
                axs[2].set_xlabel('f [GHz]')
                axs[2].set_ylim(-np.pi, +np.pi)

        # Line colours in plotting order; reused by the geometry figure.
        colorlist = []
        for i in idx:
            ant = antennas[i]

            n_samples = len(ant.t)
            # NOTE(review): this is a time span divided by a sample count,
            # i.e. a time step rather than a rate, and its inverse is what is
            # handed to lib.get_freq_spec below -- confirm that this matches
            # the argument that helper expects.
            samplerate = (ant.t[-1] - ant.t[0])/n_samples

            # Mark the start time of this antenna's trace.
            axs[0].axvline(ant.t[0], color='k', alpha=0.5)

            traces = select_traces(ant, args.polarisations, beacon_amp)

            for direction, trace in traces.items():
                line = axs[0].plot(ant.t, trace, label=f"$E_{{{direction}}}$ {ant.name}", alpha=0.7)
                color = line[0].get_color()
                colorlist.append(color)

                if not plot_ft_amplitude:
                    continue

                fft, freqs = lib.get_freq_spec(trace, 1/samplerate)

                axs[1].plot(freqs, np.abs(fft)**2, color=color)

                # Single-frequency transform at the beacon frequency, drawn
                # as a marker on top of the spectrum.
                cft = lib.direct_fourier_transform(f_beacon, ant.t, trace)
                amp = (cft[0]**2 + cft[1]**2)

                print(amp)
                phase = np.arctan2(cft[0], cft[1])
                axs[1].plot(f_beacon, amp, color=color, marker='3', alpha=0.8, ms=30)
                if len(axs) > 2:
                    axs[2].plot(f_beacon, phase, color=color, marker='3', alpha=0.8, ms=30)

        if plot_ft_amplitude:
            fig1.legend(loc='center right', ncol=min(2, len(idx)))
        else:
            axs[0].legend(loc='upper right', ncol=min(3, len(idx)))

        # Keep trace plot symmetric around 0
        max_lim = max(np.abs(axs[0].get_ylim()))
        axs[0].set_ylim(-max_lim, max_lim)

        # Keep spectrum between 0 and 100 MHz
        if len(axs) > 1:
            xlims = axs[1].get_xlim()
            axs[1].set_xlim(max(0, xlims[0]), min(0.1, xlims[1]))
            if zoom_on_beacon:
                axs[1].set_xlim(f_beacon - 0.01, f_beacon + 0.01)

        if fig_dir:
            fig1.savefig(path.join(fig_dir, path.basename(__file__) + f".trace{name_dist}.pdf"))

    if plot_geometry:
        # One colour per antenna: the colour of its first plotted trace
        # (taken from the last trace figure).
        geom_colorlist = colorlist[::n_traces_per_antenna]

        fig2, axs2 = plt.subplots(1, figsize=figsize)
        # All antennas in grey as background, the selected ones in colour.
        plot_antenna_geometry(antennas, ax=axs2, plot_max_values=False, color='grey', plot_names=False)
        plot_antenna_geometry([antennas[i] for i in idx], ax=axs2, colors=geom_colorlist, plot_max_values=False)

        axs2.plot(tx.x, tx.y, marker='X', color='k')
        axs2.set_title("Geometry with selected antennas")
        if fig_dir:
            fig2.savefig(path.join(fig_dir, path.basename(__file__) + ".geom.pdf"))

    if show_plots:
        plt.show()