#!/usr/bin/env python3
# vim: fdm=indent ts=4

import h5py
from itertools import combinations, product
import matplotlib.pyplot as plt
import numpy as np

import aa_generate_beacon as beacon
import lib


def select_baselines(antennas, ref_ant_id=None):
    """Build the list of antenna baselines to analyse.

    Parameters
    ----------
    antennas : sequence
        Antennas as returned by ``beacon.read_beacon_hdf5``.
    ref_ant_id : list of int or None
        ``None`` selects every unordered antenna pair. Otherwise each index
        picks a reference antenna from ``antennas`` and the baselines are all
        (reference, antenna) pairs, including the reference with itself.

    Returns
    -------
    (baselines, ref_ants)
        ``baselines`` is a list of 2-tuples of antennas; ``ref_ants`` is the
        list of selected reference antennas, or ``None`` when all pairs are used.
    """
    if ref_ant_id is None:
        return list(combinations(antennas, 2)), None

    ref_ants = [antennas[idx] for idx in ref_ant_id]
    return list(product(ref_ants, antennas)), ref_ants


def baseline_title_suffix(ref_ants):
    """Return the text that closes the "(i,j" baseline notation in plot titles.

    ``')'`` when all baselines are used, otherwise ``'=[ids])'`` listing the
    integer names of the reference antennas.
    """
    if ref_ants is None:
        return ')'
    return '=' + str([int(ant.name) for ant in ref_ants]) + ')'


def measured_phase_diffs(baselines, freq_name):
    """Collect the beacon frequency and measured phase difference per baseline.

    Parameters
    ----------
    baselines : sequence of (antenna, antenna) pairs
        Each antenna must expose ``beacon_info[freq_name]['freq']`` and, once
        it has been determined, ``beacon_info[freq_name]['true_phase']``.
    freq_name
        Key into each antenna's ``beacon_info``.

    Returns
    -------
    numpy.ndarray, shape (len(baselines), 2)
        Column 0: beacon frequency read from the first antenna of the baseline.
        Column 1: ``true_phase[second] - true_phase[first]`` reduced with
        ``lib.phase_mod``, or NaN when a ``true_phase`` is not available.
    """
    n_baselines = len(baselines)
    phase_diffs = np.empty((n_baselines, 2))

    for i, (ant_first, ant_second) in enumerate(baselines):
        if i % 1000 == 0:
            print(i, "out of", n_baselines)

        f_beacon = ant_first.beacon_info[freq_name]['freq']

        try:
            true_phases = np.array([ant.beacon_info[freq_name]['true_phase'] for ant in (ant_first, ant_second)])
            true_phases_diff = lib.phase_mod(lib.phase_mod(true_phases[1]) - lib.phase_mod(true_phases[0]))
        except (KeyError, IndexError):
            # true_phase not determined yet: beacon_info is keyed by strings,
            # so a missing entry raises KeyError (IndexError kept for safety).
            print(f"Missing true_phases for {freq_name} in baseline {ant_first.name},{ant_second.name}")
            true_phases_diff = np.nan

        phase_diffs[i] = [f_beacon, true_phases_diff]

    return phase_diffs


def actual_phase_diffs(baselines, antennas, f_beacon):
    """Phase differences implied by the simulated antenna clock offsets.

    Each antenna's phase is ``-2*pi*attrs['clock_offset']*f_beacon``. For a
    baseline (a, b) the result is ``phase[b] - phase[a]`` reduced with
    ``lib.phase_mod``, i.e. the same ordering as ``measured_phase_diffs``.

    Returns
    -------
    numpy.ndarray with one entry per baseline.
    """
    ant_phases = {
        ant.name: -2*np.pi*ant.attrs['clock_offset']*f_beacon
        for ant in antennas
    }

    diffs = []
    for ant_first, ant_second in baselines:
        diff = lib.phase_mod(lib.phase_mod(ant_phases[ant_second.name]) - lib.phase_mod(ant_phases[ant_first.name]))
        # NOTE(review): the extra outer phase_mod is kept from the original
        # script; it is redundant if lib.phase_mod is idempotent.
        diffs.append(lib.phase_mod(diff))

    return np.array(diffs)


def plot_phase_comparison(measured, actual, f_beacon, title_suffix, plot_residuals=False):
    """Plot measured against actual baseline phase differences.

    The top panel is a histogram; the bottom panel shows the value per
    baseline number. With ``plot_residuals`` the ``lib.phase_mod``-reduced
    residual ``measured - actual`` is drawn instead of both series.

    Parameters
    ----------
    measured, actual : array-like of equal length
        Phase differences per baseline [rad].
    f_beacon : float
        Beacon frequency, used for the secondary (time) axis on top.
    title_suffix : str
        Closes the "(i,j" notation in the title (see ``baseline_title_suffix``).
    plot_residuals : bool
        Plot residuals instead of the two series.

    Returns
    -------
    matplotlib.figure.Figure
    """
    measured = np.asarray(measured)
    actual = np.asarray(actual)
    colors = ['blue', 'orange']
    baseline_no = np.arange(len(measured))

    fig, (ax_hist, ax_base) = plt.subplots(2, 1, sharex=True)

    # Secondary x-axis expressing a phase as a time: t = phi / (2 pi f)
    secax = ax_hist.secondary_xaxis(
        'top',
        functions=(lambda x: x/(2*np.pi*f_beacon), lambda x: 2*np.pi*x*f_beacon),
    )
    secax.set_xlabel('Time $\\Delta\\varphi/(2\\pi f_{beac})$ [ns]')

    ax_hist.set_ylabel("#")
    ax_base.set_ylabel("Baseline no.")

    if plot_residuals:
        phase_residuals = lib.phase_mod(measured - actual)

        fig.suptitle("Difference between Measured and Actual phase difference\n for Baselines (i,j" + title_suffix)
        ax_base.set_xlabel("Baseline Phase Residual $\\Delta\\varphi_{ij_{meas}} - \\Delta\\varphi_{ij_{true}}$ [rad]")

        ax_hist.hist(phase_residuals, bins='sqrt', density=False, alpha=0.8, color=colors[0])
        ax_base.plot(phase_residuals, baseline_no, alpha=0.6, ls='none', marker='x', color=colors[0])
    else:
        fig.suptitle("Comparison Measured and Actual phase difference\n for Baselines (i,j" + title_suffix)
        ax_base.set_xlabel("Baseline Phase $\\Delta\\varphi_{ij}$ [rad]")

        ax_hist.hist(measured, bins='sqrt', density=False, alpha=0.8, color=colors[0], ls='solid', histtype='step', label='Measured')
        ax_hist.hist(actual, bins='sqrt', density=False, alpha=0.8, color=colors[1], ls='dashed', histtype='step', label='Actual')
        # The histogram labels were previously never shown
        ax_hist.legend()

        ax_base.plot(measured, baseline_no, alpha=0.8, color=colors[0], ls='none', marker='x', label='calculated')
        ax_base.plot(actual, baseline_no, alpha=0.8, color=colors[1], ls='none', marker='+', label='actual time shifts')
        ax_base.legend()

    fig.tight_layout()
    return fig


if __name__ == "__main__":
    from os import path
    import sys

    import os
    import matplotlib
    if os.name == 'posix' and "DISPLAY" not in os.environ:
        # No display available: use a non-interactive backend
        matplotlib.use('Agg')

    from scriptlib import MyArgumentParser
    parser = MyArgumentParser()
    args = parser.parse_args()

    fname = "ZH_airshower/mysim.sry"
    show_plots = args.show_plots
    # None: use all antenna pairs. A list of antenna indices (e.g. [50])
    # restricts the baselines to those containing these reference antennas.
    ref_ant_id = None

    fname_dir = path.dirname(fname)
    antennas_fname = path.join(fname_dir, beacon.antennas_fname)
    # Baseline results are written into the antenna file
    # (alternative: a separate 'time_diffs.hdf5').
    time_diffs_fname = antennas_fname

    fig_dir = args.fig_dir  # set None to disable saving

    if not path.isfile(antennas_fname):
        print("Antenna file cannot be found, did you try generating a beacon?")
        sys.exit(1)

    # Read in antennas from file
    f_beacon, tx, antennas = beacon.read_beacon_hdf5(antennas_fname)

    baselines, ref_ants = select_baselines(antennas, ref_ant_id)
    if ref_ants is None:
        print("Doing all baselines")
    else:
        print("Doing all baselines with {}".format([int(a.name) for a in ref_ants]))

    # For now, only one beacon_frequency is supported
    freq_names = antennas[0].beacon_info.keys()
    if len(freq_names) > 1:
        raise NotImplementedError
    freq_name = next(iter(freq_names))

    # Get phase difference per baseline
    phase_diffs = measured_phase_diffs(baselines, freq_name)
    if len(phase_diffs):
        # From here on use the frequency stored on the antennas (that of the
        # last baseline) instead of the file-level value; with a single
        # beacon frequency these presumably agree.
        f_beacon = phase_diffs[-1, 0]

    beacon.write_baseline_time_diffs_hdf5(time_diffs_fname, baselines, phase_diffs[:,1], [0]*len(phase_diffs), phase_diffs[:,0])

    # Compare with the phase differences implied by the actual clock offsets
    my_phase_diffs = actual_phase_diffs(baselines, antennas, f_beacon)

    title_suffix = baseline_title_suffix(ref_ants)
    for plot_residuals in (False, True):
        fig = plot_phase_comparison(phase_diffs[:, 1], my_phase_diffs, f_beacon, title_suffix, plot_residuals=plot_residuals)

        if fig_dir:
            extra_name = "residuals" if plot_residuals else "measured"
            fig.savefig(path.join(fig_dir, path.basename(__file__) + f".{extra_name}.F{freq_name}.pdf"))

    if show_plots:
        plt.show()