#!/usr/bin/env python3
# vim: fdm=indent ts=4

import h5py
from itertools import combinations, zip_longest
import matplotlib.pyplot as plt
import numpy as np

import aa_generate_beacon as beacon
import lib

if __name__ == "__main__":
    from os import path
    import sys

    fname = "ZH_airshower/mysim.sry"

    ####
    fname_dir = path.dirname(fname)
    antennas_fname = path.join(fname_dir, beacon.antennas_fname)

    if not path.isfile(antennas_fname):
        print("Antenna file cannot be found, did you try generating a beacon?")
        sys.exit(1)

    # Read in antennas from file
    f_beacon, tx, antennas = beacon.read_beacon_hdf5(antennas_fname)

    if True and 'beacon_phase_true' in antennas[0].attrs:
        true_phases = np.array([a.attrs['beacon_phase_true'] for a in antennas])
    else:
        true_phases = np.empty( (len(antennas)) )

        for i, ant in enumerate(antennas):
            measured_phase = ant.attrs['beacon_phase_measured']

            geom_time = lib.geometry_time(tx, ant, c_light=3e8*1e-9)
            geom_phase = geom_time * 2*np.pi*f_beacon

            true_phases[i] = lib.phase_mod(measured_phase) - lib.phase_mod(geom_phase)

            ant.attrs['beacon_phase_true'] = true_phases[i]

    # Plot True Phases
    if True:
        fig, ax = plt.subplots()
        spatial_unit=None
        fig.suptitle('f= {:2.0f}MHz'.format(f_beacon*1e3))

        antenna_locs = list(zip(*[(ant.x, ant.y) for ant in antennas]))
        ax.set_xlabel('x' if spatial_unit is None else 'x [{}]'.format(spatial_unit))
        ax.set_ylabel('y' if spatial_unit is None else 'y [{}]'.format(spatial_unit))
        scatter_kwargs = {}
        scatter_kwargs['cmap'] = 'Spectral_r'
        scatter_kwargs['vmin'] = -np.pi
        scatter_kwargs['vmax'] = +np.pi
        color_label='$\\varphi$'

        sc = ax.scatter(*antenna_locs, c=true_phases, **scatter_kwargs)
        fig.colorbar(sc, ax=ax, label=color_label)

    # run over all baselines
    if True:
        baselines = list(combinations(antennas,2))
    # use ref_ant
    else:
        ref_ant = antennas[0]
        baselines = list(zip_longest([], antennas, fillvalue=ref_ant))

    integer_periods = None
    # read integer ks from file if possible
    # and save beacon_phase_true
    with h5py.File(antennas_fname, 'a') as fp:
        for i, ant in enumerate(antennas):
            name = ant.name
            # set true beacon_phase
            fp['antennas'][name].attrs['beacon_phase_true'] = true_phases[i]

        # read integer period from file
        if True and 'beacon_ks' in fp:
            integer_periods = np.array(fp['beacon_ks'])


    # Determine integer multiple of periods to shift
    if integer_periods is None:
        integer_periods = np.empty( (len(baselines), 3) )
        for i, base in enumerate(baselines):
            # Delta between first timestamp from both antennas
            delta_t_a = base[0].t[0] - base[1].t[0]
            # + phase difference
            delta_t_p = np.diff([ant.attrs['beacon_phase_true'] for ant in base])[0]/(2*np.pi*f_beacon)

            sampling_dt = (base[1].t[1] - base[1].t[0])

            print("DT(A,P)", delta_t_a, delta_t_p, 1/f_beacon)

            # which traces to keep track of
            traces = [ base[0].Ex, base[1].Ex ]

            # how many samples to shift
            ks, maxima = lib.coherence_sum_maxima(-1*traces[0], -1*traces[1])
            max_idx = np.argmax(maxima)
            delta_t_c = sampling_dt*ks[max_idx] # ns
            print("K", ks[max_idx], sampling_dt, '=', delta_t_c)

            k, rest = np.divmod(delta_t_c, f_beacon)
            integer_periods[i] = [int(base[0].name), int(base[1].name), k]


            print(k, rest*f_beacon, delta_t_p)

            # Only continue for two random combinations
            if i not in [ 50, 51 ]:
                continue

            fig, ax = plt.subplots()
            ax.set_xlabel("k")
            ax.set_ylabel("Maximum correlation")
            ax.plot(ks, maxima)
            ax.plot(ks[max_idx], maxima[max_idx], marker='X')

            fig, ax = plt.subplots()
            dt = base[1].t[1] - base[1].t[0]
            ax.set_xlabel('t')
            ax.plot(base[0].t, traces[0], label='Reference')
            ax.plot(base[1].t, traces[1], label='Original', alpha=0.4)
            ax.plot(base[1].t + delta_t_a + delta_t_c, traces[1], label='Coherence', alpha=0.6)

            ax.legend()

    # Save integer periods to antennas
    with h5py.File(antennas_fname, 'a') as fp:
        group_name = 'beacon_ks'
        if group_name in fp:
            del fp[group_name]

        fp.create_dataset(group_name, data=integer_periods)

    plt.show()
    # Report back to CLI
    print("Period Multiples resolved in", antennas_fname)