"""
Helper routines for matplotlib figures: width annotations and beacon/impulse synchronisation plots.
"""

import matplotlib.pyplot as plt
import numpy as np

def annotate_width(ax, name, x1, x2, y, text_kw=None, arrow_kw=None):
    """
    Annotate a horizontal width on an axis with a double-headed arrow and a label.

    The arrow runs from ``(x1, y)`` to ``(x2, y)``. The label ``name`` is placed
    at the midpoint ``((x1 + x2) / 2, y)``, bottom-aligned and horizontally
    centred, so it sits on top of the arrow.

    Parameters
    ----------
    ax : matplotlib Axes (anything with an ``annotate`` method)
        Axis to draw on.
    name : str
        Label text for the width.
    x1, x2 : float
        Start and end of the width in data coordinates.
    y : float
        Height at which the arrow and label are drawn.
    text_kw : dict, optional
        Extra keyword arguments for the label annotation. They override the
        defaults (``va``, ``ha``, ``xy``).
    arrow_kw : dict, optional
        Extra keyword arguments for the arrow annotation. They override the
        defaults (``xy``, ``xytext``, ``arrowprops``).

    Returns
    -------
    list
        ``[arrow_annotation, label_annotation]`` as returned by ``ax.annotate``.
    """
    # None sentinels instead of shared mutable dict defaults.
    if text_kw is None:
        text_kw = {}
    if arrow_kw is None:
        arrow_kw = {}

    default_arrow_kw = dict(
        xy=(x1, y),
        xytext=(x2, y),
        arrowprops=dict(
            arrowstyle="<->",
            # Do not shrink the arrow away from its end points, so the
            # arrow heads touch x1 and x2 exactly.
            shrinkA=False,
            shrinkB=False,
        ),
    )

    default_text_kw = dict(
        va='bottom',
        ha='center',
        xy=((x1 + x2) / 2, y),
    )

    # Empty text: this annotation only draws the arrow.
    an1 = ax.annotate("", **{**default_arrow_kw, **arrow_kw})
    an2 = ax.annotate(name, **{**default_text_kw, **text_kw})

    return [an1, an2]


def beacon_sync_figure(
    time, impulses, beacons,
    delta_t=0,
    beacon_offsets=(),
    impulse_offsets=(),
    f_beacon=1,
    colors=('y', 'g'),
    show_annotations=False,
    multiplier_name=('m', 'n'),
    fig_kwargs=None,
    ns=1e-3
):
    """
    Plot the impulse and beacon traces of two antennas and mark their timing.

    Each antenna gets one axis showing its impulse trace and beacon trace. Both
    traces are plotted against ``(time - delta_t[i]) / ns``. On top of them:

    - impulse times are marked with solid vertical lines;
    - every beacon period is marked with faint dashed lines;
    - the reference beacon tick (the beacon offset itself) is marked with a
      dashed line in the antenna colour.

    With ``show_annotations`` an extra bottom axis collects the impulse and
    reference lines of both antennas. The widths between them are then
    annotated: ``A_i``, ``B_i``, ``t_phi`` and ``Delta t``.

    Parameters
    ----------
    time : numpy array
        Sample times shared by all traces.
    impulses, beacons : sequence of array_like
        Per-antenna impulse and beacon traces. Entries 0 and 1 are plotted.
    delta_t : float or sequence of float
        Time shift per antenna. A scalar ``d`` is interpreted as ``[0, d]``.
    beacon_offsets, impulse_offsets : float or sequence of float
        Per-antenna beacon reference time and impulse time. A scalar is used
        for both antennas. An empty sequence draws nothing for that quantity.
    f_beacon : float
        Beacon frequency. Beacon ticks are spaced ``1 / f_beacon`` apart.
    colors : sequence of str
        Line colour per antenna.
    show_annotations : bool
        Add the summary axis and the width annotations.
    multiplier_name : sequence of str
        Symbol used for the number of periods in the ``B_i`` label.
    fig_kwargs : dict, optional
        Passed to ``plt.subplots``. Defaults to ``{'figsize': (12, 4)}``.
    ns : float
        Divisor that converts input times to the values shown on the
        "Time [ns]" axis.

    Returns
    -------
    fig, axes
        The figure and its array of axes (2 or 3 rows).
    """
    if fig_kwargs is None:
        fig_kwargs = {'figsize': (12, 4)}

    # A scalar delta_t shifts only the second antenna.
    if not hasattr(delta_t, "__len__"):
        delta_t = [0, delta_t]
    delta_t = np.asarray(delta_t, dtype=float)

    # A scalar offset applies to both antennas. Converting to arrays makes
    # the tick arithmetic below work for plain Python lists as well.
    if not hasattr(impulse_offsets, "__len__"):
        impulse_offsets = np.repeat(impulse_offsets, 2)
    impulse_offsets = np.asarray(impulse_offsets, dtype=float)

    if not hasattr(beacon_offsets, "__len__"):
        beacon_offsets = np.repeat(beacon_offsets, 2)
    beacon_offsets = np.asarray(beacon_offsets, dtype=float)

    n_axes = 3 if show_annotations else 2

    fig, axes = plt.subplots(n_axes, 1, sharex=True, **fig_kwargs)
    axes[-1].set_xlabel("Time [ns]")
    for i in range(2):
        axes[i].set_yticks([])
        axes[i].set_ylabel("Antenna {:d}".format(i+1))
        shifted_time = (time - delta_t[i]) / ns
        axes[i].plot(shifted_time, impulses[i])
        axes[i].plot(shifted_time, beacons[i], marker='.')

    # Indicate the timing of the impulses (also on the summary axis).
    for i, impulse_offset in enumerate(impulse_offsets):
        target_axes = [axes[i], axes[-1]] if show_annotations else [axes[i]]
        for ax in target_axes:
            ax.axvline((impulse_offset - delta_t[i]) / ns, color=colors[i])

    # Number of beacon periods spanned by the time axis (at least one tick).
    n_periods = 1 + int((time[-1] - time[0]) * f_beacon)

    # Indicate the timing of the beacon.
    for i, beacon_offset in enumerate(beacon_offsets):
        ref_kwargs = dict(color=colors[i], ls=(0, (3, 2)))
        tick_kwargs = {**ref_kwargs, 'color': 'k', 'alpha': 0.2}
        target_axes = [axes[i], axes[-1]] if show_annotations else [axes[i]]

        # One faint tick per beacon period, starting at the beacon offset.
        beacon_ticks = beacon_offset + np.arange(n_periods) / f_beacon
        for tick in beacon_ticks:
            axes[i].axvline((tick - delta_t[i]) / ns, **tick_kwargs)

        # The reference tick is the first tick, i.e. the beacon offset itself.
        ref_tick = beacon_ticks[0]
        for ax in target_axes:
            ax.axvline((ref_tick - delta_t[i]) / ns, **ref_kwargs)

        # Per-antenna annotations need a matching impulse time.
        if show_annotations and i < len(impulse_offsets):
            impulse_offset = impulse_offsets[i]

            # Pick the last beacon tick at or before the impulse. Fall back to
            # the first tick when the impulse precedes all ticks.
            closest_id = np.argmin(np.abs(beacon_ticks - impulse_offset))
            if closest_id != 0 and beacon_ticks[closest_id] > impulse_offset:
                closest_id -= 1
            closest_tick = beacon_ticks[closest_id]

            # A_i: closest beacon tick -> impulse.
            # B_i: closest beacon tick -> reference tick.
            annotate_width(axes[i], f"$A_{i+1}$", (closest_tick - delta_t[i])/ns, (impulse_offset - delta_t[i])/ns, 0.7)
            annotate_width(axes[i], f"$B_{i+1}={multiplier_name[i]}T$", (closest_tick - delta_t[i])/ns, (ref_tick - delta_t[i])/ns, 0.4)

    if show_annotations:
        axes[-1].set_yticks([])

        # Width between the beacon reference ticks of the two antennas.
        if len(beacon_offsets):
            annotate_width(axes[-1], r"$t_\phi$", (beacon_offsets[0] - delta_t[0])/ns, (beacon_offsets[-1] - delta_t[-1])/ns, 0.4)

        # Width between the impulses of the two antennas.
        if len(impulse_offsets):
            annotate_width(axes[-1], r"$\Delta t$", (impulse_offsets[0] - delta_t[0])/ns, (impulse_offsets[-1] - delta_t[-1])/ns, 0.4)

    return fig, axes