#!/usr/bin/env python3
# vim: fdm=indent ts=4

"""
Find the beacon frequency, phase and amplitude in each antenna trace
and save them to the antennas HDF5 file.
"""

import h5py
import matplotlib.pyplot as plt
import numpy as np

import aa_generate_beacon as beacon
import lib

def freq_group_name(frequency, n_sig=3):
    """Return the HDF5 group name under which a beacon frequency is stored.

    The frequency is rounded to ``n_sig`` decimals past its leading digit
    (i.e. ``n_sig + 1`` significant figures), so that numerically close
    frequencies map onto the same group name.

    A zero or non-finite frequency has no leading digit (``log10`` would
    give -inf/NaN and ``int()`` would raise); it is stringified unrounded
    instead of aborting halfway through writing the file.
    """
    if frequency == 0 or not np.isfinite(frequency):
        return str(frequency)

    decimal = int(np.floor(np.log10(abs(frequency))))
    return str(np.around(frequency, n_sig - decimal))


def figure_path(fig_dir, suffix, script=__file__):
    """Return the path in *fig_dir* for a figure named after *script*.

    Only the basename of *script* is used: since Python 3.9 ``__file__`` of
    the main script is absolute, and ``os.path.join`` drops every earlier
    component when it meets an absolute one, which would silently ignore
    *fig_dir*.
    """
    from os import path
    return path.join(fig_dir, path.basename(script) + suffix)


if __name__ == "__main__":
    from os import path
    import sys
    import matplotlib
    import os

    # Use a non-interactive backend when no X display is available
    if os.name == 'posix' and "DISPLAY" not in os.environ:
        matplotlib.use('Agg')

    #### configuration
    f_beacon_band = (49e-3,55e-3) #GHz
    allow_frequency_fitting = False
    read_frequency_from_file = True

    # optional first CLI argument (0/1): analyse the E_AxB trace
    use_AxB_trace = True if len(sys.argv) < 2 else bool(int(sys.argv[1]))
    use_beacon_trace = True # only applicable if AxB = False

    show_plots = True
    plot_antennas = (0, 70, 72) # indices of the antennas to plot

    fname = "ZH_airshower/mysim.sry"

    print("use_AxB_trace:", use_AxB_trace, "use_beacon_trace:",use_beacon_trace)

    ####
    fname_dir = path.dirname(fname)
    antennas_fname = path.join(fname_dir, beacon.antennas_fname)

    fig_dir = "./figures" # set None to disable saving

    if not path.isfile(antennas_fname):
        print("Antenna file cannot be found, did you try generating a beacon?")
        sys.exit(1)

    if fig_dir:
        # figures are saved inside fig_dir; savefig fails if it is missing
        os.makedirs(fig_dir, exist_ok=True)

    # read in antennas; append mode because results are written back below
    with h5py.File(antennas_fname, 'a') as fp:
        if 'antennas' not in fp.keys():
            print("Antenna file corrupted? no antennas")
            sys.exit(1)

        group = fp['antennas']

        # Beacon frequency estimate and the band around it to search in
        f_beacon = None
        if read_frequency_from_file and 'tx' in fp:
            tx = fp['tx']
            if 'f_beacon' in tx.attrs:
                f_beacon = tx.attrs['f_beacon']
            else:
                print("No frequency found in file.")
                sys.exit(2)
            f_beacon_estimate_band = 0.01*f_beacon

        elif allow_frequency_fitting:
            # centre of f_beacon_band, +- half of its width
            f_beacon_estimate_band = (f_beacon_band[1] - f_beacon_band[0])/2
            f_beacon = f_beacon_band[1] - f_beacon_estimate_band
        else:
            print("Not allowed to fit frequency and no tx group found in file.")
            sys.exit(2)

        N_antennas = len(group.keys())
        # per antenna (frequency, phase, amplitude); only used for the
        # summary histograms at the end
        found_data = np.zeros((N_antennas, 3))

        # Determine frequency and phase
        for i, name in enumerate(group.keys()):
            h5ant = group[name]

            # use E_AxB only instead of polarisations
            if use_AxB_trace:
                if 'E_AxB' not in h5ant.keys():
                    print(f"Antenna does not have 'E_AxB' in {name}")
                    sys.exit(1)

                traces = h5ant['E_AxB']

                # row 0 is the time axis, row 1 the E_AxB trace
                t_trace = traces[0]
                test_traces = [ traces[1] ]
                orients = ['E_AxB']

                # TODO: refine masking
                # use beacon but remove where E_AxB-Beacon != 0
                if True:
                    if not True:
                        t_mask = np.isclose(h5ant['E_AxB'][1], h5ant['traces'][4], rtol=1e-3, atol=1e-3)
                    else:
                        t_mask = np.ones(len(t_trace), dtype=bool)
                        t_mask[1500:3000] = False # magic numbers from aa_generate_beacon

                    t_trace = t_trace[t_mask]
                    for j, t in enumerate(test_traces):
                        test_traces[j] = t[t_mask]
                        orients[j] = orients[j] + ' masked'

            # use separate polarisations
            else:
                if 'traces' not in h5ant.keys():
                    print(f"Antenna file corrupted? no 'traces' in {name}")
                    sys.exit(1)

                traces = h5ant['traces']
                t_trace = traces[0]

                if use_beacon_trace:
                    # only take the Beacon trace
                    test_traces = [traces[4]]
                    orients = ['B']
                else:
                    test_traces = traces[1:]
                    orients = ['Ex', 'Ey', 'Ez', 'B']

            # modify the length of the traces
            if False:
                t_trace = t_trace[:len(t_trace)//2]
                half_traces = []
                for trace in test_traces:
                    half_traces.append( trace[:len(trace)//2])
                test_traces = half_traces

            # Do Fourier Transforms
            # to find phases and amplitudes
            if True:
                freqs, phases, amps = lib.find_beacon_in_traces(
                        test_traces, t_trace,
                        f_beacon_estimate=f_beacon,
                        frequency_fit=allow_frequency_fitting,
                        f_beacon_estimate_band=f_beacon_estimate_band
                        )
            else:
                # Testing
                freqs = [f_beacon]
                t0 = h5ant.attrs['t0']

                phases = [ 2*np.pi*t0*f_beacon ]
                amps = [ 3e-7 ]

            # choose highest amp
            idx = 0
            if False and len(phases) > 1:
                #idx = np.argmax(amplitudes, axis=-1)
                raise NotImplementedError

            frequency = freqs[idx]
            phase = phases[idx]
            amplitude = amps[idx]
            orientation = orients[idx]

            # Correct for phase by t_trace[0]
            corr_phase = lib.phase_mod(2*np.pi*f_beacon*t_trace[0])
            if False:
                # Subtract phase due to not starting at t=0
                # This is already done in beacon_find_traces
                phase = lib.phase_mod(phase + corr_phase)

            # for reporting using plots
            found_data[i] = frequency, phase, amplitude

            if (show_plots or fig_dir) and i in plot_antennas:
                # phase [rad] -> time at the beacon frequency
                p2t = lambda phase: phase/(2*np.pi*f_beacon)

                fig, ax = plt.subplots()
                ax.set_title(f"Beacon at antenna {h5ant.attrs['name']}\nF:{frequency:.2e}, P:{phase:.4f}, A:{amplitude:.1e}")
                ax.set_xlabel("t [ns]")
                ax.set_ylabel("Amplitude")

                if True:
                    # let the trace start at t=0
                    t_0 = min(t_trace)
                    extra_phase = corr_phase
                else:
                    t_0 = 0
                    extra_phase = -1*corr_phase

                for j, trace in enumerate(test_traces):
                    ax.plot(t_trace - t_0, trace, marker='.', label='trace '+orients[j])

                myt = np.linspace(min(t_trace), max(t_trace), 10*len(t_trace)) - t_0
                ax.plot(myt, lib.sine_beacon(frequency, myt, amplitude=amplitude, t0=0, phase=phase+extra_phase), ls='dotted', label='simulated beacon')

                ax.axvline( p2t(lib.phase_mod(-1*(phase+extra_phase), low=0)), color='r', ls='dashed', label='$t_\\varphi$')

                ax.axvline(0,color='grey',alpha=0.5)
                ax.axhline(0,color='grey',alpha=0.5)

                ax.legend()

                if fig_dir:
                    old_xlims = ax.get_xlim()
                    ax.set_xlim(min(t_trace)-t_0-10,min(t_trace)-t_0+40)

                    fig.savefig(figure_path(fig_dir, f".A{h5ant.attrs['name']}.zoomed.pdf"))

                    ax.set_xlim(*old_xlims)
                    fig.savefig(figure_path(fig_dir, f".A{h5ant.attrs['name']}.pdf"))

            # save to file
            h5beacon_info = h5ant.require_group('beacon_info')

            # group name is the frequency rounded to n_sig+1 significant
            # figures, so re-runs overwrite the same group
            freq_name = freq_group_name(frequency, n_sig=3)

            # delete previous values
            if freq_name in h5beacon_info:
                del h5beacon_info[freq_name]

            h5beacon_freq_info = h5beacon_info.create_group(freq_name)

            h5attrs = h5beacon_freq_info.attrs
            h5attrs['freq'] = frequency
            h5attrs['phase'] = phase
            h5attrs['amplitude'] = amplitude
            h5attrs['orientation'] = orientation

    print("Beacon Phases, Amplitudes and Frequencies written to", antennas_fname)

    # show histogram of found frequencies
    if show_plots or fig_dir:
        if True or allow_frequency_fitting:
            fig, ax = plt.subplots()
            ax.set_xlabel("Frequency")
            ax.set_ylabel("Counts")
            ax.axvline(f_beacon, ls='dashed', color='g')
            ax.hist(found_data[:,0], bins='sqrt', density=False)
            if fig_dir:
                fig.savefig(figure_path(fig_dir, ".hist_freq.pdf"))

        if True:
            fig, ax = plt.subplots()
            ax.set_xlabel("Amplitudes")
            ax.set_ylabel("Counts")
            ax.hist(found_data[:,2], bins='sqrt', density=False)
            if fig_dir:
                fig.savefig(figure_path(fig_dir, ".hist_amp.pdf"))

    if show_plots:
        plt.show()