import numpy as np
from functools import partial

def distance(x1, x2):
    """
    Calculate the Euclidean distance between two locations x1 and x2.

    Both arguments are converted with ``np.asarray``, so plain sequences
    (lists, tuples) are accepted as well as NumPy arrays. Coordinates run
    along the last axis; leading axes broadcast, so e.g. an (N, D) array of
    points against a single (D,) point yields N distances.
    """
    # Without the conversion, two plain lists would fail on ``x1 - x2``.
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    return np.sqrt(np.sum((x1 - x2) ** 2, axis=-1))

def plot_geometry(ax, emitters=None, antennae=None, unit='m'):
    """
    Show the geometry of emitters and antennae in a square plot.

    Each emitter is drawn with a '*' marker and labelled "E0", "E1", ...;
    each antenna is drawn with an 'o' marker and labelled "A0", "A1", ....

    Parameters
    ----------
    ax - matplotlib.axes.Axes
        The axis object to plot the geometry on.
    emitters - list of Locations, optional
        Locations of the emitters to plot. Defaults to no emitters.
    antennae - list of Locations, optional
        Locations of the antennae (receivers) to plot. Defaults to no
        antennae.
    unit - str, optional
        Length unit shown in the axis labels, 'm' by default.

    Each location's ``x`` is unpacked positionally into ``ax.plot``, so the
    coordinates are presumably 2-D (x, y) for this plot.

    Returns
    -------
    ax - matplotlib.axes.Axes
        The axis object containing the plotted geometry.
    annots - dict of list of matplotlib.text.Annotation
        Annotations keyed by "E" (emitters) and "A" (antennae), in the same
        order as the input lists. Both keys are always present.
    """
    # None defaults avoid sharing one mutable list across calls.
    emitters = [] if emitters is None else emitters
    antennae = [] if antennae is None else antennae

    ax.grid()
    ax.set_title("Geometry of Emitter(s) and Antennae")
    ax.set_ylabel("y ({})".format(unit))
    ax.set_xlabel("x ({})".format(unit))
    ax.margins(0.3)
    ax.set_aspect('equal', 'datalim')  # make it a square plot

    # label prefix -> (marker, locations); the prefix is also the annots key
    groups = {"E": ("*", emitters), "A": ("o", antennae)}

    annots = {}
    for prefix, (marker, locs) in groups.items():
        annots[prefix] = []
        # plot marker and create annotation
        for j, loc in enumerate(locs):
            label = "{}{}".format(prefix, j)
            ax.plot(*loc.x, marker=marker, label=label)
            annots[prefix].append(ax.annotate(label, loc.x))

    return ax, annots

class Location:
    """
    A location is a point designated by a spatial coordinate x.

    Locations are wrappers around a Numpy N-dimensional array, stored in the
    attribute ``x``. Adding, subtracting or multiplying a Location with
    another Location, a NumPy array, a sequence or a scalar returns a new
    Location (element-wise, with NumPy broadcasting).
    """

    # Make NumPy defer to Location's (reflected) operators, so that e.g.
    # ``ndarray - Location`` calls Location.__rsub__ and returns a Location
    # instead of an object-dtype array of Locations.
    __array_ufunc__ = None

    # Locations compare by value and wrap a mutable array, so they are
    # unhashable (defining __eq__ already implies this; made explicit here).
    __hash__ = None

    def __init__(self, x):
        # np.asarray does not copy an existing ndarray: a Location built from
        # an array shares it, and __setitem__ then mutates the caller's array.
        self.x = np.asarray(x)

    def __repr__(self):
        return "Location({})".format(repr(self.x))

    def __getitem__(self, key):
        return self.x[key]

    def __setitem__(self, key, val):
        self.x[key] = val

    @staticmethod
    def _unwrap(other):
        """Return the coordinate array of a Location, or `other` unchanged."""
        return other.x if isinstance(other, Location) else other

    def distance(self, other):
        """Euclidean distance to `other` (a Location or raw coordinates)."""
        return distance(self.x, self._unwrap(other))

    # math
    def __add__(self, other):
        return self.__class__(self.x + self._unwrap(other))

    def __sub__(self, other):
        return self.__class__(self.x - self._unwrap(other))

    def __rsub__(self, other):
        # Computes ``other - self``; operand order matters, so this must not
        # alias __sub__ (which would silently flip the sign).
        return self.__class__(self._unwrap(other) - self.x)

    def __mul__(self, other):
        # Unwrap like __add__/__sub__ so ``loc * loc`` is element-wise.
        return self.__class__(self.x * self._unwrap(other))

    def __eq__(self, other):
        try:
            return bool(np.all(self.x == self._unwrap(other)))
        except ValueError:
            # Shapes that cannot broadcast (e.g. a 2-D vs a 3-D point):
            # recent NumPy raises here instead of returning False.
            return False

    # math alias functions (addition and multiplication commute)
    __radd__ = __add__
    __rmul__ = __mul__