#!/usr/bin/env python3

__doc__ = \
r"""
For each antenna i calculate the differences with the other antennas j.
Do these sets of differences match up to an initial difference \Delta_{ii'}?
"""

from itertools import chain, combinations, product
import numpy as np
import matplotlib.pyplot as plt
# Module-wide random generator (the __main__ block rebinds it with a fixed seed).
rng = np.random.default_rng()

# Unit system: lengths are in metres, times in nanoseconds.
ns = 1e-9 # s, i.e. one nanosecond expressed in seconds
km = 1e3 # m
c_light = 3e8*ns # m/ns: 3e8 m/s * 1e-9 s/ns = 0.3 m/ns

class Antenna:
    """
    Simple Antenna class: a named point (x, y, z) that carries a clock time.

    Parameters
    ----------
    x, y, z : float
        Position coordinates.
    t0 : float
        Initial clock time; stored on the instance as the attribute ``t``.
    name : str or int
        Free-form label, also shown in ``repr``.
    """
    def __init__(self, x=0, y=0, z=0, t0=0, name=""):
        self.x, self.y, self.z = x, y, z
        # Stored as ``t`` (not ``t0``) because it is updated in place later,
        # e.g. by add_spatial_time_delay.
        self.t = t0
        self.name = name

    def __repr__(self):
        # The keyword is ``t0`` so the repr round-trips through __init__.
        cls = type(self).__name__
        return f'{cls}(x={self.x!r},y={self.y!r},z={self.z!r},t0={self.t!r},name={self.name!r})'

def distance(x1, x2):
    """
    Calculate the Euclidean distance between the positions of two antennas.

    Parameters
    ----------
    x1, x2 : Antenna
        Objects whose ``x``, ``y`` and ``z`` attributes hold the position.
        Instances of :class:`Antenna` subclasses are accepted.

    Returns
    -------
    float
        The straight-line distance, in the same unit as the coordinates.

    Raises
    ------
    TypeError
        If either argument is not an :class:`Antenna`.
    """
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    for arg in (x1, x2):
        if not isinstance(arg, Antenna):
            raise TypeError(
                "distance() expects Antenna instances, got {}".format(type(arg).__name__)
            )

    p1 = np.array([x1.x, x1.y, x1.z])
    p2 = np.array([x2.x, x2.y, x2.z])

    return np.sqrt(np.sum((p1 - p2)**2))

def geometry_time(dist, x2=None, c_light=c_light):
    """
    Return the propagation time ``distance / c_light``.

    Parameters
    ----------
    dist : float or Antenna
        A distance; or, when ``x2`` is given, the first of two antennas whose
        separation is computed with :func:`distance`.
    x2 : Antenna, optional
        Second antenna; if omitted, ``dist`` is used directly as the distance.
    c_light : float
        Propagation speed (defaults to the module's ``c_light``, in m/ns).
    """
    if x2 is None:
        path_length = dist
    else:
        path_length = distance(dist, x2)
    return path_length / c_light

def phase_mod(phase, low=np.pi):
    r"""
    Modulo phase such that it falls within the
    interval $[-low, 2\pi - low)$.

    Works element-wise on numpy arrays. With the default ``low = pi`` the
    result lies in $[-\pi, \pi)$.

    Parameters
    ----------
    phase : float or array_like
        Phase(s) in radians.
    low : float
        Negated lower bound of the target interval.

    Returns
    -------
    float or numpy.ndarray
        The wrapped phase(s).
    """
    return (phase + low) % (2*np.pi) - low

def antenna_triangles(antennas):
    """Return an iterator over every unordered triple drawn from *antennas*."""
    triangle_size = 3
    return combinations(antennas, triangle_size)

def antenna_baselines(antennas):
    """Return an iterator over every unordered pair (baseline) drawn from *antennas*."""
    pair_size = 2
    return combinations(antennas, pair_size)

def add_spatial_time_delay(tx, antennas, time=geometry_time, t_scale=1):
    """
    Add the propagation delay from ``tx`` to each antenna's clock, in place.

    For every antenna, ``time(tx, antenna) / t_scale`` is added to its ``t``
    attribute. Nothing is returned.
    """
    for receiver in antennas:
        delay = time(tx, receiver) / t_scale
        receiver.t += delay

def random_antenna(N_ant=1, antenna_ranges=(10e3, 10e3, 10e3), max_clock_skew=1):
    """
    Create ``N_ant`` antennas at uniformly random positions.

    Random numbers come from the module-level ``rng``.

    Parameters
    ----------
    N_ant : int
        Number of antennas to create; they are named ``0 .. N_ant-1``.
    antenna_ranges : sequence of 3 floats
        Upper bound along x, y and z; each coordinate is drawn uniformly
        from ``[0, range)``.
    max_clock_skew : float or None
        Standard deviation of the normally distributed clock offset ``t0``.
        Despite the name, this is not a hard maximum. ``None`` gives ``t0 = 0``.

    Returns
    -------
    list of Antenna
    """
    ranges = np.asarray(antenna_ranges)

    antennas = []
    for i in range(N_ant):
        loc = ranges * rng.random(3)
        t0 = 0 if max_clock_skew is None else rng.normal(0, max_clock_skew)

        # z takes the third coordinate, so it respects antenna_ranges[2].
        antennas.append(Antenna(name=i, x=loc[0], y=loc[1], z=loc[2], t0=t0))

    return antennas

def single_baseline_referenced_sigmas(tx, baseline, all_antennas, phase_func=None):
    """
    Residual time differences of each non-baseline antenna with respect to
    every antenna of ``baseline``.

    Row ``k`` belongs to the k-th antenna ``a`` of ``all_antennas`` that is not
    in ``baseline``. Column ``m`` of that row holds
    ``(a.t - b_m.t) - (geometry_time(tx, a) - geometry_time(tx, b_m))`` for
    baseline antenna ``b_m``. If ``phase_func`` is given, it is applied to each row.

    Parameters
    ----------
    tx : Antenna
        Transmitter used for the geometric delays.
    baseline : sequence of Antenna
        Reference antennas (typically a pair, e.g. from antenna_baselines).
    all_antennas : sequence of Antenna
        Antennas to compare; members of ``baseline`` are skipped.
    phase_func : callable, optional
        Maps an array of time differences to e.g. wrapped phases.

    Returns
    -------
    numpy.ndarray, shape (n_other, len(baseline))
    """
    baseline_ts = np.array([b.t for b in baseline])
    baseline_geo = np.array([geometry_time(tx, b) for b in baseline])

    # Antenna defines no __eq__, so membership falls back to identity.
    others = [ant for ant in all_antennas if ant not in baseline]

    # Size from the actual selection rather than assuming N_ant - 2.
    sigmas = np.empty((len(others), len(baseline)))
    for j, ant in enumerate(others):
        t_diff = ant.t - baseline_ts
        geo_diff = geometry_time(tx, ant) - baseline_geo
        residual = t_diff - geo_diff
        sigmas[j] = residual if phase_func is None else phase_func(residual)

    return sigmas

def reference_antenna_sigmas(tx, ref_ant, all_antennas, phase_func=None):
    """
    Residual time differences of all antennas with respect to ``ref_ant``.

    Element ``i`` is
    ``(ant_i.t - ref_ant.t) - (geometry_time(tx, ant_i) - geometry_time(tx, ref_ant))``.
    If ``phase_func`` is given, it is applied to that value. For
    ``ant_i is ref_ant`` the difference is exactly 0, so no special case is needed.

    Parameters
    ----------
    tx : Antenna
        Transmitter used for the geometric delays.
    ref_ant : Antenna
        Reference antenna.
    all_antennas : sequence of Antenna
        Antennas to compare against ``ref_ant``.
    phase_func : callable, optional
        Maps a time difference to e.g. a wrapped phase.

    Returns
    -------
    numpy.ndarray, shape (len(all_antennas),)
    """
    ref_geo = geometry_time(tx, ref_ant)

    sigmas = np.empty(len(all_antennas))
    for i, ant in enumerate(all_antennas):
        t_diff = ant.t - ref_ant.t
        geo_diff = geometry_time(tx, ant) - ref_geo
        residual = t_diff - geo_diff
        sigmas[i] = residual if phase_func is None else phase_func(residual)

    return sigmas

def all_sigmas_using_reference_antenna(tx, all_antennas, phase_func=None):
    """
    Use every antenna in turn as the reference for reference_antenna_sigmas.

    Returns
    -------
    numpy.ndarray, shape (N, N)
        Row ``i`` holds the residual differences of all antennas with
        respect to ``all_antennas[i]``.
    """
    n = len(all_antennas)
    table = np.empty((n, n))
    # Each `row` is a view into `table`, so slice assignment fills it in place.
    for row, reference in zip(table, all_antennas):
        row[:] = reference_antenna_sigmas(tx, reference, all_antennas, phase_func=phase_func)
    return table

def _color_settings(plot_phase, f_beacon, time_unit):
    """Return ``(phase_func, color_label, default_scatter_kwargs)`` for main()."""
    if plot_phase:
        def phase_func(t):
            # Time difference -> beacon phase, wrapped to [-pi, pi).
            return phase_mod(2*np.pi* f_beacon * t)

        color_label = '$\\varphi$'
        default_kwargs = {'cmap': 'Spectral_r', 'vmin': -np.pi, 'vmax': +np.pi}
    else:
        phase_func = None
        color_label = 't' if time_unit is None else 't ['+time_unit+']'
        default_kwargs = {}
    return phase_func, color_label, default_kwargs


def _figure_title(remove_minimum, plot_phase, f_beacon, time_unit):
    """Build the figure suptitle that describes the processing main() applied."""
    title = ""
    if remove_minimum:
        title += r'$\sigma_{0j}$ added'
    if remove_minimum and plot_phase:
        title += ', '
    if plot_phase:
        # With times in ns, f_beacon is per ns; scale it to Hz before printing MHz.
        t_scaler = 1e9 if time_unit == 'ns' else 1
        title += 'f= {:2.0f}MHz'.format(f_beacon*t_scaler/1e6)
    return title


def main(tx, antennas, spatial_unit=None, time_unit=None, ref_idx=(0, 1, -2, -1), plot_phase=False, remove_minimum=True, f_beacon=50e6, scatter_kwargs=None):
    """
    Plot the residual differences for four reference antennas.

    Residuals are computed with every antenna as reference
    (all_sigmas_using_reference_antenna). The rows selected by ``ref_idx``
    are shown on a 2x2 grid as scatter plots of the antenna (x, y) positions,
    coloured by residual, with the reference antenna marked by a red cross.

    Parameters
    ----------
    tx : Antenna
        Transmitter used for the geometric delays.
    antennas : sequence of Antenna
        Antennas to plot.
    spatial_unit, time_unit : str, optional
        Unit labels for the axes and colour bar. ``time_unit='ns'`` also
        means ``f_beacon`` is per ns when it is printed in MHz.
    ref_idx : sequence of int
        Reference-antenna indices, one per subplot. At least 4 entries are
        required because the grid is 2x2.
    plot_phase : bool
        Convert differences to beacon phases wrapped to [-pi, pi).
    remove_minimum : bool
        Re-reference every row to the first antenna (see comment below).
    f_beacon : float
        Beacon frequency, in the inverse of the time unit.
    scatter_kwargs : dict, optional
        Extra keyword arguments for ``Axes.scatter``; these override the defaults.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the 2x2 grid.
    sigmas : numpy.ndarray, shape (N, N)
        The plotted residuals, after the optional re-referencing and wrapping.
    """
    phase_func, color_label, default_scatter_kwargs = _color_settings(plot_phase, f_beacon, time_unit)
    scatter_kwargs = {**default_scatter_kwargs, **(scatter_kwargs or {})}

    # sigmas[i, j] = (t_j - t_i) - (geo_j - geo_i), optionally through phase_func
    sigmas = all_sigmas_using_reference_antenna(tx, antennas, phase_func=phase_func)

    if remove_minimum:
        # Despite the name, this adds sigmas[0, i] (differences w.r.t. the first
        # reference antenna) to row i, not the per-row minimum
        # (-np.min(sigmas, axis=-1)). This alignment is required for the phases.
        mins = sigmas[0]
        sigmas = sigmas + mins[:, np.newaxis]

    if plot_phase:
        # The addition above can leave the wrapped interval, so wrap again.
        sigmas = phase_mod(sigmas)

    fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
    fig.suptitle(_figure_title(remove_minimum, plot_phase, f_beacon, time_unit))

    x_label = 'x' if spatial_unit is None else 'x [{}]'.format(spatial_unit)
    y_label = 'y' if spatial_unit is None else 'y [{}]'.format(spatial_unit)
    antenna_locs = list(zip(*[(ant.x, ant.y) for ant in antennas]))
    for i, ax in enumerate(axs.flat):
        ref = ref_idx[i]
        ax.set_title("Ref Antenna: {}".format(ref))
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        sc = ax.scatter(*antenna_locs, c=sigmas[ref], **scatter_kwargs)
        fig.colorbar(sc, ax=ax, label=color_label)
        ax.plot(antennas[ref].x, antennas[ref].y, 'rx')

    return fig, sigmas


if __name__ == "__main__":
    from argparse import ArgumentParser
    from os import path
    # Rebind the module-level generator with a fixed seed for a reproducible layout.
    rng = np.random.default_rng(1)

    parser = ArgumentParser(description=__doc__)
    # NOTE: both positionals are optional, so a lone positional value is taken as
    # `fname`. Pass 'none' as fname to show the figure while still setting num_ant.
    parser.add_argument("fname", metavar="path/to/figure[/]", nargs="?", help="Location for generated figure, will append __file__ if a directory. If not supplied, figure is shown.")
    parser.add_argument('num_ant', help='Number of antennas to use', nargs='?', default=5, type=int)
    # Re-referencing is enabled by default; --no-remove-min disables it.
    # --remove-min is kept so existing invocations still parse.
    parser.add_argument('--remove-min', dest='remove_min', action='store_true', help='Re-reference all differences to the first antenna (default)')
    parser.add_argument('--no-remove-min', dest='remove_min', action='store_false', help='Plot the differences per reference antenna unmodified')
    parser.set_defaults(remove_min=True)

    command_group = parser.add_mutually_exclusive_group(required=False)
    command_group.add_argument('--time',  help='Calculate times (Default)', action='store_true')
    command_group.add_argument('--phase', help='Calculate wrapped phases', action='store_true')

    args = parser.parse_args()
    args.rm_minimum = args.remove_min

    args.plot_phase = args.phase
    del args.time, args.phase

    if args.fname == 'none':
        args.fname = None

    if args.fname is not None:
        if path.isdir(args.fname):
            args.fname = path.join(args.fname, path.splitext(path.basename(__file__))[0]) # leave off extension
        if not path.splitext(args.fname)[1]:
            # No extension given: write both a PDF and a PNG.
            args.fname = [ args.fname+ext for ext in ['.pdf', '.png'] ]

    ###### Set-up (lengths in m, times in ns)
    antenna_ranges = np.array([10*km,10*km,5*km])
    antenna_max_clock_skew = 100*ns/ns # 0.1 us, expressed in ns
    f_beacon = 50e6*ns # 50 MHz, expressed in 1/ns

    tx = Antenna(name='tx', x=-300*km, y=200*km, z=0)
    antennas = random_antenna(args.num_ant, antenna_ranges, antenna_max_clock_skew)
    add_spatial_time_delay(tx, antennas)

    fig, sigmas = main(tx, antennas, spatial_unit='m', time_unit='ns', plot_phase=args.plot_phase, remove_minimum=args.rm_minimum, f_beacon=f_beacon)

    ###### Output
    if args.fname is not None:
        if isinstance(args.fname, str):
            args.fname = [args.fname]

        for fname in args.fname:
            fig.savefig(fname)
    else:
        plt.show()