#!/usr/bin/env python3
# vim: fdm=indent ts=4

"""
Find beacon phases in antenna traces
And save these to a file
"""

import h5py
import matplotlib.pyplot as plt
import numpy as np

import aa_generate_beacon as beacon
import lib
from lib import figlib


if __name__ == "__main__":
    # Script entry point: for every antenna in the antennas file, find the
    # beacon (frequency, phase, amplitude) in the selected trace and store it
    # in the antenna's 'beacon_info' group. If a noise trace (row 5) is
    # present, its phase/amplitude at the beacon frequency is stored as well.
    # Optional diagnostic plots are shown and/or saved to fig_dir.
    from os import path
    import sys
    import matplotlib
    import os
    # Fall back to a non-interactive backend when no X display is available.
    if os.name == 'posix' and "DISPLAY" not in os.environ:
        matplotlib.use('Agg')

    from scriptlib import MyArgumentParser
    parser = MyArgumentParser()

    group1 = parser.add_mutually_exclusive_group()
    group1.add_argument('--AxB', dest='use_AxB_trace', action='store_true', help='Only use AxB trace, if both AxB and beacon are not used, we use the antenna polarisations.')
    group1.add_argument('--beacon', dest='use_beacon_trace', action='store_true', help='Only use the beacon trace')

    parser.add_argument('--N-mask', type=float, default=500, help='Mask N_MASK samples around the absolute maximum of the trace. (Default: %(default)d)')

    args = parser.parse_args()

    f_beacon_band = (49e-3,55e-3) #GHz
    allow_frequency_fitting = False
    read_frequency_from_file = True
    N_mask = int(args.N_mask)

    use_only_AxB_trace = args.use_AxB_trace
    use_only_beacon_trace = args.use_beacon_trace # mutually exclusive with --AxB

    # NOTE(review): show_plots, data_dir and fig_dir are presumably added by
    # MyArgumentParser; they are not declared in this block.
    show_plots = args.show_plots

    figsize = (12,8)

    print("use_only_AxB_trace:", use_only_AxB_trace, "use_only_beacon_trace:", use_only_beacon_trace)

    ####
    fname_dir = args.data_dir
    antennas_fname = path.join(fname_dir, beacon.antennas_fname)
    beacon_snr_fname = path.join(fname_dir, beacon.beacon_snr_fname)

    fig_dir = args.fig_dir # set None to disable saving

    if not path.isfile(antennas_fname):
        print("Antenna file cannot be found, did you try generating a beacon?")
        sys.exit(1)

    beacon_snrs = beacon.read_snr_file(beacon_snr_fname)
    snr_str = f"$\\langle SNR \\rangle$ = {beacon_snrs['mean']: .1e}"

    # read in antennas (opened in append mode: results are written back)
    with h5py.File(antennas_fname, 'a') as fp:
        if 'antennas' not in fp.keys():
            print("Antenna file corrupted? no antennas")
            sys.exit(1)

        group = fp['antennas']

        # Beacon frequency: either read from the 'tx' group or, if allowed,
        # estimated from the centre of f_beacon_band.
        f_beacon = None
        if read_frequency_from_file and 'tx' in fp:
            tx = fp['tx']
            if 'f_beacon' in tx.attrs:
                f_beacon = tx.attrs['f_beacon']
            else:
                print("No frequency found in file.")
                sys.exit(2)
            f_beacon_estimate_band = 0.01*f_beacon

        elif allow_frequency_fitting:
            f_beacon_estimate_band = (f_beacon_band[1] - f_beacon_band[0])/2
            f_beacon = f_beacon_band[1] - f_beacon_estimate_band
        else:
            print("Not allowed to fit frequency and no tx group found in file.")
            sys.exit(2)

        N_antennas = len(group.keys())
        # just for funzies
        found_data = np.zeros((N_antennas, 3)) # freq, phase, amp
        noise_data = np.zeros((N_antennas, 2)) # phase, amp

        # Determine frequency and phase
        for i, name in enumerate(group.keys()):
            h5ant = group[name]

            # Row 0 of every traces dataset is the time axis.
            # use E_AxB only instead of polarisations
            if use_only_AxB_trace:
                traces_key = 'E_AxB'
                if traces_key not in h5ant.keys():
                    print(f"Antenna does not have '{traces_key}' in {name}")
                    sys.exit(1)

                traces = h5ant[traces_key]
                t_trace = traces[0]
                test_traces = [ traces[1] ]
                orients = ['E_AxB']

            # Only beacon (row 4 of 'filtered_traces')
            elif use_only_beacon_trace:
                traces_key = 'filtered_traces'
                if traces_key not in h5ant.keys():
                    print(f"Antenna file corrupted? no '{traces_key}' in {name}")
                    sys.exit(1)

                traces = h5ant[traces_key]
                t_trace = traces[0]
                test_traces = [traces[4]]
                orients = ['B']

            # use separate polarisations (rows 1-3 of 'filtered_traces')
            else:
                traces_key = 'filtered_traces'
                if traces_key not in h5ant.keys():
                    print(f"Antenna file corrupted? no '{traces_key}' in {name}")
                    sys.exit(1)

                traces = h5ant[traces_key]
                t_trace = traces[0]
                # use `k` so the antenna index `i` is not shadowed
                test_traces = [traces[k] for k in range(1,4)]
                orients = ['Ex', 'Ey', 'Ez']

            # Really only select the first component
            if True:
                test_traces = [test_traces[0]]
                orients = [orients[0]]

            # TODO: refine masking
            # use beacon but remove where E_AxB-Beacon != 0
            # Uses the first traces as reference.
            # t_mask stays None when no masking is applied.
            t_mask = None
            if N_mask and orients[0] != 'B':
                # Split N_mask so exactly N_mask samples are masked, also for odd values.
                N_pre = N_mask//2
                N_post = N_mask - N_pre

                # NOTE(review): this is the global maximum of the signed trace,
                # not of |trace|; confirm this is the intended reference point.
                max_idx = np.argmax(test_traces[0])

                low_idx = max(0, max_idx-N_pre)
                high_idx = min(len(t_trace), max_idx+N_post)

                t_mask = np.ones(len(t_trace), dtype=bool)
                t_mask[low_idx:high_idx] = False

                t_trace = t_trace[t_mask]
                for j, t in enumerate(test_traces):
                    test_traces[j] = t[t_mask]
                    orients[j] = orients[j] + ' masked'

            # Do Fourier Transforms
            # to find phases and amplitudes
            if True:
                freqs, beacon_phases, amps = lib.find_beacon_in_traces(
                        test_traces, t_trace,
                        f_beacon_estimate=f_beacon,
                        frequency_fit=allow_frequency_fitting,
                        f_beacon_estimate_band=f_beacon_estimate_band
                        )
            else:
                # Testing
                freqs = [f_beacon]
                t0 = h5ant.attrs['t0']

                beacon_phases = [ 2*np.pi*t0*f_beacon ]
                amps = [ 3e-7 ]

            # Also try to find the phase from the noise trace if available.
            # The noise trace is row 5, so at least 6 rows are required.
            # Reset per antenna so values never leak from a previous antenna.
            noise_phase = None
            noise_amp = None
            all_traces = h5ant[traces_key]
            if len(all_traces) > 5:
                noise_trace = all_traces[5]
                if t_mask is not None: # Mask the same area
                    noise_trace = noise_trace[t_mask]

                real, imag = lib.direct_fourier_transform(f_beacon, t_trace, noise_trace)
                noise_phase = np.arctan2(imag, real)
                noise_amp = (real**2 + imag**2)**0.5

                noise_data[i] = noise_phase, noise_amp

            # choose highest amp
            idx = np.argmax(amps)
            if False and len(beacon_phases) > 1:
                #idx = np.argmax(amplitudes, axis=-1)
                raise NotImplementedError

            frequency = freqs[idx]
            beacon_phase = beacon_phases[idx]
            amplitude = amps[idx]
            orientation = orients[idx]

            # Phase offset due to the trace not starting at t=0
            # (used for plotting below).
            corr_phase = lib.phase_mod(2*np.pi*f_beacon*t_trace[0])
            if False:
                # Subtract phase due to not starting at t=0
                # This is already done in beacon_find_traces
                beacon_phase = lib.phase_mod(beacon_phase + corr_phase)

            # for reporting using plots
            found_data[i] = frequency, beacon_phase, amplitude

            # Diagnostic plots for a couple of hard-coded antenna indices.
            if (show_plots or fig_dir) and (i == 0 or i == 72):
                p2t = lambda phase: phase/(2*np.pi*f_beacon)

                fig, ax = plt.subplots(figsize=figsize)
                ax.set_title(f"Beacon at antenna {h5ant.attrs['name']}\nF:{frequency:.2e}GHz, $\\varphi$:{beacon_phase:.4f}rad")
                ax.set_xlabel("t [ns]")
                ax.set_ylabel("Amplitude")

                if True:
                    # let the trace start at t=0
                    t_0 = min(t_trace)
                    extra_phase = corr_phase
                else:
                    t_0 = 0
                    extra_phase = -1*corr_phase

                for orient, trace in zip(orients, test_traces):
                    ax.plot(t_trace - t_0, trace, marker='.', label='trace '+orient)

                myt = np.linspace(min(t_trace), max(t_trace), 10*len(t_trace)) - t_0
                ax.plot(myt, lib.sine_beacon(frequency, myt, amplitude=amplitude, t0=0, phase=beacon_phase+extra_phase), ls='dotted', label='simulated beacon')

                ax.axvline( p2t(lib.phase_mod(-1*(beacon_phase+extra_phase), low=0)), color='r', ls='dashed', label='$t_\\varphi$')

                ax.axvline(0,color='grey',alpha=0.5)
                ax.axhline(0,color='grey',alpha=0.5)

                ax.legend(title=snr_str)

                if fig_dir:
                    old_xlims = ax.get_xlim()
                    ax.set_xlim(min(t_trace)-t_0-10,min(t_trace)-t_0+40)

                    fig.savefig(path.join(fig_dir, path.basename(__file__) + f".A{h5ant.attrs['name']}.zoomed.pdf"))

                    ax.set_xlim(*old_xlims)
                    fig.savefig(path.join(fig_dir, path.basename(__file__) + f".A{h5ant.attrs['name']}.pdf"))

            # save to file
            h5beacon_info = h5ant.require_group('beacon_info')

            # only take n_sig significant digits into account
            # for naming in hdf5 file (guard log10(0) -> -inf)
            n_sig = 3
            decimal = int(np.floor(np.log10(abs(frequency)))) if frequency else 0
            freq_name = str(np.around(frequency, n_sig-decimal))

            # delete previous values
            if freq_name in h5beacon_info:
                del h5beacon_info[freq_name]

            h5beacon_freq_info = h5beacon_info.create_group(freq_name)

            h5attrs = h5beacon_freq_info.attrs
            h5attrs['freq'] = frequency
            h5attrs['beacon_phase'] = beacon_phase
            h5attrs['amplitude'] = amplitude
            h5attrs['orientation'] = orientation

            # Only store noise values computed for THIS antenna (0.0 is a valid phase).
            if noise_phase is not None:
                h5attrs['noise_phase'] = noise_phase
                h5attrs['noise_amp'] = noise_amp

    print("Beacon Phases, Amplitudes and Frequencies written to", antennas_fname)

    # show histogram of found frequencies
    if show_plots or fig_dir:

        if True or allow_frequency_fitting:
            fig, ax = plt.subplots(figsize=figsize)
            ax.set_xlabel("Frequency")
            ax.set_ylabel("Counts")
            ax.axvline(f_beacon, ls='dashed', color='g')
            ax.hist(found_data[:,0], bins='sqrt', density=False)
            ax.legend(title=snr_str)
            if fig_dir:
                fig.savefig(path.join(fig_dir, path.basename(__file__) + f".hist_freq.pdf"))

        if True:
            fig, _ = figlib.fitted_histogram_figure(found_data[:,2], fit_distr=['rice'])
            ax = fig.axes[0]
            ax.set_xlabel("Amplitude")
            ax.set_ylabel("Counts")
            ax.hist(found_data[:,2], bins='sqrt', density=False)
            ax.legend(title=snr_str)
            if fig_dir:
                fig.savefig(path.join(fig_dir, path.basename(__file__) + f".hist_amp.pdf"))

        # Noise plots only when some antenna had a noise trace
        # (also safe when there are no antennas at all).
        if noise_data.any():
            if True:
                fig, ax = plt.subplots(figsize=figsize)
                ax.set_title("Noise Phases")
                ax.set_xlabel("Phase")
                ax.set_ylabel("#")
                ax.hist(noise_data[:,0], bins='sqrt', density=False)
                ax.legend(title=snr_str)
                if fig_dir:
                    fig.savefig(path.join(fig_dir, path.basename(__file__) + f".noise.hist_phase.pdf"))

            if True:
                fig, ax = plt.subplots(figsize=figsize)
                ax.set_title("Noise Phase vs Amplitude")
                ax.set_xlabel("Phase")
                ax.set_ylabel("Amplitude [a.u.]")
                ax.plot(noise_data[:,0], noise_data[:,1], ls='none', marker='x')
                if fig_dir:
                    fig.savefig(path.join(fig_dir, path.basename(__file__) + f".noise.phase_vs_amp.pdf"))

            if True:
                fig, _ = figlib.fitted_histogram_figure(noise_data[:,1], fit_distr=['rice', 'rayleigh'])
                ax = fig.axes[0]
                ax.set_title("Noise Amplitudes")
                ax.set_xlabel("Amplitude [a.u.]")
                ax.set_ylabel("#")
                ax.legend(title=snr_str)

                if fig_dir:
                    fig.savefig(path.join(fig_dir, path.basename(__file__) + f".noise.hist_amp.pdf"))

    if show_plots:
        plt.show()