# from wappy import *
from earsim import *
from atmocal import *

import matplotlib.pyplot as plt
from scipy.signal import hilbert
from scipy import signal
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit,minimize
import pandas as pd
import os

# Use a larger default font size for every figure this module produces.
plt.rcParams['font.size'] = 16

# Shared atmosphere model; the functions below query it for light travel
# times and for converting slant depth X into distance along the axis.
atm = AtmoCal()

from matplotlib import cm

def set_pol_and_bp(e,low=0.03,high=0.08):
    """Project each antenna's field onto ``e.uAxB`` and band-pass filter it.

    For every antenna in ``e.antennas`` the components ``Ex``, ``Ey``, ``Ez``
    are dotted, sample by sample, with ``e.uAxB``. The projected trace is then
    passed to ``block_filter`` together with the sample spacing (taken from
    the first two entries of ``ant.t``) and the band edges ``low``/``high``
    (presumably GHz when ``t`` is in ns -- confirm against block_filter).

    Side effects: sets ``ant.E_AxB`` (filtered trace) and ``ant.t_AxB``
    (the very same object as ``ant.t``) on every antenna.
    """
    axis = e.uAxB
    for antenna in e.antennas:
        samples = zip(antenna.Ex, antenna.Ey, antenna.Ez)
        projected = [np.dot(axis, [ex, ey, ez]) for ex, ey, ez in samples]
        step = antenna.t[1] - antenna.t[0]
        antenna.E_AxB = block_filter(projected, step, low, high)
        antenna.t_AxB = antenna.t


def pow_and_time(test_loc,ev,dt=1.0):
    """Delay-and-sum all antenna traces for a hypothetical source at ``test_loc``.

    Each antenna's time axis ``t_AxB`` is shifted back by the light travel
    time from ``test_loc`` to the antenna (``atm.light_travel_time`` scaled
    by 1e9, presumably s -> ns). Each shifted trace ``E_AxB`` is linearly
    interpolated onto a common grid of spacing ``dt``, taking the value 0
    outside the trace. The interpolated traces are then summed. The common
    grid runs from 1 after the earliest shifted start to 1 before the latest
    shifted end.

    Returns
    -------
    P : float
        Sum of |FFT|^2 of the summed trace (equivalently, by Parseval,
        len(a_sum) times its energy). It is 0 when the common grid is empty,
        in which case an error line is printed.
    t_ : list
        Shifted time axis of each antenna.
    a_ : list
        Trace (``E_AxB``) of each antenna.
    a_sum : ndarray
        Summed trace on the common grid.
    t_sum : ndarray
        The common time grid.
    """
    shifted_times = []
    traces = []
    earliest = 1e9
    latest = -1e9
    for antenna in ev.antennas:
        # Travel time from the test location to this antenna, removed from its clock.
        delay, _distance = atm.light_travel_time(test_loc, [antenna.x, antenna.y, antenna.z])
        times = np.subtract(antenna.t_AxB, delay * 1e9)
        shifted_times.append(times)
        traces.append(antenna.E_AxB)
        earliest = min(earliest, times[0])
        latest = max(latest, times[-1])

    t_sum = np.arange(earliest + 1, latest - 1, dt)
    a_sum = np.zeros(len(t_sum))
    for times, trace in zip(shifted_times, traces):
        resample = interp1d(times, trace, assume_sorted=True,
                            bounds_error=False, fill_value=0.)
        a_sum = np.add(a_sum, resample(t_sum))

    if len(a_sum) == 0:
        print("ERROR, a_sum lenght = 0",
              "tmin ", earliest,
              "t_max ", latest,
              "dt", dt)
        P = 0
    else:
        P = np.sum(np.square(np.absolute(np.fft.fft(a_sum))))
    return P, shifted_times, traces, a_sum, t_sum

def shower_axis_slice(e,Xb=200,Xe=1200,dX=2,zgr=0):
    """Scan the coherent power along the shower axis.

    Slant depths from ``Xb`` to ``Xe`` (inclusive, ``int((Xe-Xb)/dX)``
    intervals) are converted to distances with
    ``atm.distance_to_slant_depth``. The ground height used is
    ``zgr + e.core[2]``. Each distance is placed along the direction given
    by ``e.zenith``/``e.azimuth`` (degrees), measured from the coordinate
    origin, and scored with ``pow_and_time``.

    Returns
    -------
    ds : ndarray
        Distance along the axis for each depth.
    Xs : ndarray
        The slant depths.
    locs : list
        ``[x, y, z]`` location for each depth.
    p : ndarray
        Coherent power at each location.
    """
    ground = zgr + e.core[2]
    n_steps = int((Xe - Xb) / dX)
    Xs = np.array(np.linspace(Xb, Xe, n_steps + 1))
    zen = np.deg2rad(e.zenith)
    ds = np.array([atm.distance_to_slant_depth(zen, X, ground) for X in Xs])

    # Direction cosines of the axis; multiplied by each distance below.
    azi = np.deg2rad(e.azimuth)
    ux = np.sin(zen) * np.cos(azi)
    uy = np.sin(zen) * np.sin(azi)
    uz = np.cos(zen)
    locs = [[ux * d, uy * d, uz * d] for d in ds]

    p = np.asanyarray([pow_and_time(loc, e)[0] for loc in locs])
    return ds, Xs, locs, p

def shower_plane_slice(e,X=750.,Nx=10,Ny=10,wx=1e3,wy=1e3,xoff=0,yoff=0,zgr=0):
    """Scan the coherent power on a grid in the shower plane at depth ``X``.

    The plane is spanned by ``e.uAxB`` and ``e.uAxAxB``. It is placed along
    ``e.uA`` at the distance that ``atm.distance_to_slant_depth`` returns for
    slant depth ``X``, with ground height ``zgr + e.core[2]``. The grid has
    ``Nx`` x ``Ny`` points covering ``[-wx, wx] x [-wy, wy]``, shifted by
    ``(xoff, yoff)``. Every grid point is scored with ``pow_and_time``.

    Returns
    -------
    xx, yy : ndarray
        Shower-plane coordinates of the grid points (offsets included),
        with the x coordinate as the outer loop.
    p : ndarray
        Coherent power at each grid point.
    loc : ndarray
        Location of the grid point with the largest power.
    """
    axis_dist = atm.distance_to_slant_depth(np.deg2rad(e.zenith), X, zgr + e.core[2])
    on_axis = axis_dist * e.uA
    x_grid = np.linspace(-wx, wx, Nx)
    y_grid = np.linspace(-wy, wy, Ny)

    xx, yy, powers, locs = [], [], [], []
    for gx in x_grid:
        u = gx + xoff
        for gy in y_grid:
            v = gy + yoff
            loc = u * e.uAxB + v * e.uAxAxB + on_axis
            locs.append(loc)
            powers.append(pow_and_time(loc, e)[0])
            xx.append(u)
            yy.append(v)

    powers = np.asanyarray(powers)
    return np.asarray(xx), np.asarray(yy), powers, locs[np.argmax(powers)]

def slice_figure(e,X,xx,yy,p,mode='horizontal'):
    """Scatter-plot a power map ``p`` over coordinates ``xx``, ``yy`` (divided by 1e3).

    In ``'horizontal'`` mode, a red ``+`` marks the horizontal projection of
    the axis point at slant depth ``X``. That point is computed from
    ``e.zenith``/``e.azimuth`` and ``atm.distance_to_slant_depth`` with
    ground height ``e.core[2]``. In ``'sp'`` (shower-plane) mode the red
    ``+`` marks the origin. In both modes a blue ``x`` marks the point of
    maximum power.

    Returns ``(fig, axs)``.
    """
    fig, axs = plt.subplots(1, figsize=(10, 8))
    title = (r'E = %.1f EeV, $\theta$ = %.1f$^\circ$, $\phi$ = %.1f$^\circ$ X = %.f'
             % (e.energy, e.zenith, e.azimuth, X))
    fig.suptitle(title)
    scatter = axs.scatter(xx / 1e3, yy / 1e3, c=p, cmap='Spectral_r', alpha=0.6)
    fig.colorbar(scatter, ax=axs)

    zen = np.deg2rad(e.zenith)
    azi = np.deg2rad(e.azimuth)
    ground = 0 + e.core[2]  # zgr offset is fixed at 0 here
    axis_dist = atm.distance_to_slant_depth(zen, X, ground)

    if mode == 'horizontal':
        x_axis = np.sin(zen) * np.cos(azi) * axis_dist
        y_axis = np.sin(zen) * np.sin(azi) * axis_dist
        axs.plot(x_axis / 1e3, y_axis / 1e3, 'r+', ms=30)
        axs.set_xlabel('x (km)')
        axs.set_ylabel('y (km)')
    elif mode == "sp":
        axs.plot(0, 0, 'r+', ms=30)
        axs.set_xlabel('-v x B (km)')
        axs.set_ylabel(' vxvxB (km)')

    best = np.argmax(p)
    axs.plot(xx[best] / 1e3, yy[best] / 1e3, 'bx', ms=30)
    fig.tight_layout()
    return fig, axs

def dist_to_line(xp,core,u):
    """Return the perpendicular distance from point ``xp`` to a straight line.

    Parameters
    ----------
    xp : array_like
        The point (3-vectors in this module's callers).
    core : array_like
        Any point on the line.
    u : array_like
        Direction of the line. It does not need to be normalised. A zero
        vector makes the line degenerate to the point ``core``, and the
        distance to ``core`` is returned.

    Returns
    -------
    float
        Euclidean distance of ``xp`` from the line.
    """
    offset = np.asarray(xp, dtype=float) - np.asarray(core, dtype=float)
    direction = np.asarray(u, dtype=float)
    length = np.linalg.norm(direction)
    if length == 0:
        return np.linalg.norm(offset)
    unit = direction / length
    # Remove the along-line component explicitly. The former
    # sqrt(|a|^2 - (a.u)^2) cancels catastrophically for points far along the
    # axis but close to it, and it silently required |u| == 1.
    perpendicular = offset - np.dot(offset, unit) * unit
    return np.linalg.norm(perpendicular)

def dist_to_line_sum(param,data,weights):
    """Objective for the axis fit: weighted distances of ``data`` from a line.

    ``param`` is ``(x0, y0, theta, phi)``. The line passes through
    ``(x0, y0, 0)`` with direction
    ``(cos(phi) sin(theta), sin(phi) sin(theta), cos(theta))``, with the
    angles in radians. The distance of each point (from ``dist_to_line``) is
    multiplied by the square of its weight. The products are summed and the
    sum is divided by the number of points (not by the sum of the weights).
    """
    x0, y0, theta, phi = param[0], param[1], param[2], param[3]
    core = np.array([x0, y0, 0.])
    sin_theta = np.sin(theta)
    u = np.array([np.cos(phi) * sin_theta, np.sin(phi) * sin_theta, np.cos(theta)])
    total = sum(dist_to_line(point, core, u) * w**2 for point, w in zip(data, weights))
    return total / len(data)

def _save_slice_figure(e, X, x, y, p, path, tag):
    """Draw one shower-plane slice with ``slice_figure`` and save it as ``path + 'X<X>_<tag>.pdf'``."""
    fig, _axs = slice_figure(e, X, x, y, p, 'sp')
    fig.savefig(path + 'X%d_%s.pdf' % (X, tag))
    plt.close(fig)


def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
    """Locate the coherent-power maximum in shower-plane slices at several depths.

    For each of ``N_X`` slant depths in ``[Xlow, Xhigh]`` a 21 x 21 grid is
    scanned with ``shower_plane_slice``:

    1. A coarse scan with half-width equal to 2 deg seen from the distance
       of depth 750.
    2. If its maximum lies on the grid border, the scan is repeated with a
       4 deg half-width.
    3. A fine scan with 0.2 deg half-width, centred on the best point of the
       previous scan.

    Parameters
    ----------
    e : event
        Event whose antennas already carry ``E_AxB``/``t_AxB``.
    savefig : bool
        Save every slice figure as ``path + 'X<X>_<a|c|b>.pdf'``.
    path : str
        Prefix concatenated directly with the figure names.
    zgr : float
        Ground-height offset added to ``e.core[2]``. It is used for the
        reference scan width and is also forwarded to every slice, so that
        the slice depths use the same ground height.
    Xlow, Xhigh, N_X :
        Range and number of slant depths.

    Returns
    -------
    Xsteps : ndarray
        The scanned depths.
    axis_points : list
        Location of the fine-scan maximum for each depth.
    max_vals : list
        Fine-scan maximum power for each depth.
    """
    axis_points = []
    max_vals = []
    Xsteps = np.linspace(Xlow, Xhigh, N_X)
    ground = zgr + e.core[2]  # not exact
    dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith), 750, ground)
    # Grid half-widths subtending 2, 4 and 0.2 degrees at the reference distance.
    scale2d = dXref * np.tan(np.deg2rad(2.))
    scale4d = dXref * np.tan(np.deg2rad(4.))
    scale0_2d = dXref * np.tan(np.deg2rad(0.2))
    for X in Xsteps:
        print(X)  # progress indicator
        x, y, p, loc_max = shower_plane_slice(e, X, 21, 21, scale2d, scale2d, zgr=zgr)
        if savefig:
            _save_slice_figure(e, X, x, y, p, path, 'a')
        im = np.argmax(p)
        # The coarse grid is centred on 0, so |coord| == max(coord) means the
        # maximum sits on the border; widen the search.
        if np.abs(x[im]) == np.max(x) or np.abs(y[im]) == np.max(y):
            x, y, p, loc_max = shower_plane_slice(e, X, 21, 21, scale4d, scale4d, zgr=zgr)
            if savefig:
                _save_slice_figure(e, X, x, y, p, path, 'c')
        im = np.argmax(p)
        # Fine scan centred on the best point found so far.
        x, y, p, loc_max = shower_plane_slice(e, X, 21, 21, scale0_2d, scale0_2d,
                                              x[im], y[im], zgr=zgr)
        if savefig:
            _save_slice_figure(e, X, x, y, p, path, 'b')
        max_vals.append(np.max(p))
        axis_points.append(loc_max)

    return Xsteps, axis_points, max_vals

def fit_track(e,axis_points,vals,nscale=1e0):
    """Fit a straight line (the shower axis) through ``axis_points``.

    The points are divided by ``nscale`` before fitting, so that the core
    coordinates and the angles have comparable step sizes for the optimiser.
    The weights are ``vals`` normalised to their maximum. ``dist_to_line_sum``
    is minimised with ``scipy.optimize.minimize``, starting from a core at
    the origin and the event's current zenith/azimuth. The optimiser result
    is printed.

    Returns
    -------
    zen_r, azi_r : float
        Fitted zenith and azimuth in degrees (not wrapped into any range).
    core : list
        Fitted core ``[x0, y0, 0]`` in the original (unscaled) units.
    """
    weights = np.asarray(vals) / np.max(vals)
    scaled_points = [point / nscale for point in axis_points]
    start = [0, 0, np.deg2rad(e.zenith), np.deg2rad(e.azimuth)]
    fit = minimize(dist_to_line_sum, x0=start, args=(scaled_points, weights))
    zen_r = np.rad2deg(fit.x[2])
    azi_r = np.rad2deg(fit.x[3])
    print(fit, zen_r, e.zenith, azi_r, e.azimuth)
    fitted_core = [fit.x[0] * nscale, fit.x[1] * nscale, 0]
    return zen_r, azi_r, fitted_core


def update_event(e,core,theta,phi,axp=None):
    """Re-orient event ``e`` to a new axis and shift its origin by ``core``.

    Parameters
    ----------
    e : event
        Modified in place:

        - ``zenith``/``azimuth`` are set to ``theta``/``phi``.
        - ``core`` is offset element-wise by ``core``.
        - ``uA``, ``uAxB`` (normalised) and ``uAxAxB`` are recomputed from the
          new angles and the existing ``e.uB``.
        - Every antenna position is shifted by ``-core`` (presumably to keep
          antenna coordinates relative to the updated core).
    core : array_like
        3-element offset.
    theta, phi : float
        New zenith and azimuth in degrees.
    axp : iterable of mutable 3-element points, optional
        Shifted in place by ``-core``, like the antennas.
    """
    e.zenith = theta
    e.azimuth = phi
    theta = np.deg2rad(theta)
    phi = np.deg2rad(phi)
    # np.add is element-wise even for plain lists; `list + list` would concatenate.
    e.core = np.add(e.core, core)
    e.uA = np.array([np.cos(phi)*np.sin(theta),np.sin(phi)*np.sin(theta),np.cos(theta)])
    e.uAxB = np.cross(e.uA,e.uB)
    e.uAxB = e.uAxB/(np.dot(e.uAxB,e.uAxB))**0.5
    e.uAxAxB = np.cross(e.uA,e.uAxB)
    # antenna positions follow the new origin
    for a in e.antennas:
        a.x -= core[0]
        a.y -= core[1]
        a.z -= core[2]
    # `axp != None` is element-wise (and ambiguous in `if`) for numpy arrays.
    if axp is not None:
        for ap in axp:
            ap[0] -= core[0]
            ap[1] -= core[1]
            ap[2] -= core[2]

def longitudinal_figure(dist,Xs,p,mode='grammage'):
    """Plot the power profile ``p``, normalised to its maximum, along the axis.

    ``mode='grammage'`` plots against the slant depths ``Xs``.
    ``mode='distance'`` plots against ``dist`` divided by 1e3 (labelled in km).
    Any other mode yields an empty, gridded axes.

    Returns the figure.
    """
    fig, axs = plt.subplots(1, figsize=(6, 5))
    selected = None
    if mode == 'grammage':
        selected = (Xs, 'X (g/cm$^2$)')
    elif mode == 'distance':
        selected = (dist / 1e3, 'distance from ground (km)')
    if selected is not None:
        abscissa, xlabel = selected
        axs.plot(abscissa, p / np.max(p), 'o-')
        axs.set_xlabel(xlabel)
    axs.grid()
    fig.tight_layout()
    return fig

def time_residuals(e,tlable=True):
    """Plot relative pulse times against distance and fit a cubic polynomial.

    ``lateral_parameters`` (defined outside this section) supplies the
    per-station values. Only ``ds`` (plotted as distance in m), ``tp``
    (plotted relative to its minimum, in ns) and ``sid`` (station labels)
    are used here. Stations with ``ds < 200`` and a relative time ``> 10``
    are treated as outliers. Outliers are excluded from the cubic fit and
    from the quoted sigma, but still get a residual.

    Parameters
    ----------
    e : event
    tlable : bool
        Annotate every point with its station label.

    Returns
    -------
    fig : matplotlib Figure
    tres : ndarray
        Residuals of every station with respect to the cubic fit.
    """
    ds,tp,nsum,ssum,swidth,azi,x,y,sid = lateral_parameters(e,True,[0,0,0])
    fig, axs = plt.subplots(1,figsize=(6,5))
    tp = tp-np.min(tp)
    cut_outlier = ~((ds<200)&(tp > 10))
    axs.plot(ds,tp,'o')
    if tlable:
        for d,t,s in zip(ds,tp,sid):
            # draw on this figure's axes, not whatever pyplot considers current
            axs.text(d,t,s)
    axs.set_xlabel('distance (m)')
    # raw string: '\D' is an invalid escape sequence (SyntaxWarning on 3.12+)
    axs.set_ylabel(r'$\Delta t (ns)$')
    axs.grid()
    z = np.polyfit(ds[cut_outlier],tp[cut_outlier],3)
    pfit = np.poly1d(z)
    xfit = np.linspace(np.min(ds),np.max(ds),100)
    yfit = pfit(xfit)
    tres = tp - pfit(ds)
    sigma =np.std(tres[cut_outlier])
    axs.plot(xfit,yfit,label=r'pol3 fit, $\sigma=%.2f$ (ns)'%(sigma))
    axs.legend()
    fig.tight_layout()
    return fig,tres

def figure_3D(axis_points,max_vals,zen,azi,core,res = 0):
    """Diagnostic figure for the shower-axis fit.

    The top panel is a 3-D scatter of ``axis_points`` (coordinates divided by
    1e3), coloured by ``max_vals`` and sized by the square of
    ``max_vals / max(max_vals)``. The bottom panel plots each point's
    perpendicular distance to the fitted line against its distance from
    ``core``. The fitted line passes through ``core`` with the direction
    given by ``zen``/``azi`` in degrees.

    If ``res`` is not 0 (e.g. a ``RITResult``), the distances, residuals and
    values are appended to ``res.track_dis``, ``res.track_res`` and
    ``res.track_val``.

    Returns the figure.
    """
    fig = plt.figure(figsize=(5, 9))

    top = fig.add_subplot(2, 1, 1, projection='3d')
    xs = [pt[0] / 1e3 for pt in axis_points]
    ys = [pt[1] / 1e3 for pt in axis_points]
    zs = [pt[2] / 1e3 for pt in axis_points]
    vals = np.asarray(max_vals)
    sizes = 150 * (vals / np.max(vals))**2
    top.scatter(xs, ys, zs, c=vals, s=sizes, cmap='Spectral_r')

    bottom = fig.add_subplot(2, 1, 2)
    core = np.array(core)
    theta = np.deg2rad(zen)
    phi = np.deg2rad(azi)
    direction = np.array([np.cos(phi) * np.sin(theta),
                          np.sin(phi) * np.sin(theta),
                          np.cos(theta)])
    residuals = [dist_to_line(pt, core, direction) for pt in axis_points]
    dist = [np.sum((pt - core)**2)**0.5 for pt in axis_points]
    bottom.scatter(dist, residuals, c=vals, cmap='Spectral_r')
    bottom.grid()
    fig.tight_layout()

    if res != 0:
        res.track_dis.append(dist)
        res.track_res.append(residuals)
        res.track_val.append(vals)
    return fig

class RITResult:
    """Container for the per-event outputs of the RIT reconstruction.

    Every attribute is a list; functions in this module append one entry
    per processed event.
    """
    def __init__(self):
        super().__init__()
        # Axis profile: depth of maximum power, a separate ``xmax`` slot
        # (not filled by ``reconstruction``), the power profile, and the
        # depths/distances it was sampled at.
        self.xmax_rit = []
        self.xmax = []
        self.profile_rit = []
        self.dX = []
        self.dl = []
        # Input geometry.
        self.zenith_ini = []
        self.azimuth_ini = []
        self.core_ini = []
        # Reconstructed geometry (core offset and angles in degrees).
        self.dcore_rec = []
        self.zenith_rec = []
        self.azimuth_rec = []
        # Bookkeeping.
        self.index = []
        self.isMC = []
        # Axis-fit diagnostics appended by ``figure_3D``.
        self.track_dis = []
        self.track_res = []
        self.track_val = []
        # Per-station properties (see ``fill_stations_propeties``).
        self.station_ids = []
        self.station_x = []
        self.station_y = []
        self.station_z = []
        self.station_maxE = []
        self.has_pulse = []

def fill_stations_propeties(e,res):
    """Append per-station positions, names and field maxima of ``e`` to ``res``.

    One entry per call is appended to each of:

    - ``res.station_x``, ``res.station_y``, ``res.station_z``: arrays of
      antenna coordinates.
    - ``res.station_ids``: list of antenna names.
    - ``res.station_maxE``: array of ``np.max(ant.E_AxB)`` per antenna. This
      is the signed maximum, not the absolute peak.

    The antennas must already carry ``E_AxB`` (set by ``set_pol_and_bp``).
    Pulse flags (``res.has_pulse``) are not recorded.
    """
    x = np.array([a.x for a in e.antennas])
    y = np.array([a.y for a in e.antennas])
    z = np.array([a.z for a in e.antennas])
    ids = [a.name for a in e.antennas]
    maxE = np.array([np.max(a.E_AxB) for a in e.antennas])
    res.station_x.append(x)
    res.station_y.append(y)
    res.station_z.append(z)
    res.station_ids.append(ids)
    # maxE used to be computed and then discarded, leaving station_maxE empty.
    res.station_maxE.append(maxE)

def reconstruction(e,outfile='', slice_outdir=None, Xlow=300, Xhigh=1000, N_X=15):
    """Run the interferometric axis reconstruction on event ``e``.

    Steps:

    1. Record the input geometry.
    2. Project and filter the fields (``set_pol_and_bp``).
    3. Store the station properties.
    4. Find the power maxima in shower-plane slices (``get_axis_points``).
    5. Fit a line through them (``fit_track``) and draw the fit diagnostics
       (``figure_3D``).
    6. Move ``e`` to the fitted axis (``update_event``).
    7. Scan the power along the new axis (``shower_axis_slice``).

    Parameters
    ----------
    e : event
        Modified in place (fields, geometry, antenna positions).
    outfile : str
        Where to save the axis-fit figure. If empty, the figure is not saved.
    slice_outdir : str or None
        If given, slice figures are saved with this prefix. It is
        concatenated directly with the file names, so include a trailing
        separator for a directory.
    Xlow, Xhigh, N_X :
        Depth range and number of slices passed to ``get_axis_points``.

    Returns
    -------
    RITResult
        With one entry appended to each filled list.
    """
    res = RITResult()
    res.isMC.append(True)
    res.zenith_ini.append(e.zenith)
    res.azimuth_ini.append(e.azimuth)
    res.core_ini.append(e.core)

    set_pol_and_bp(e)

    # Record station properties before update_event moves the antennas.
    fill_stations_propeties(e,res)
    _X_slices,axis_points,max_vals = get_axis_points(e,savefig=(slice_outdir is not None), path=slice_outdir, Xlow=Xlow, Xhigh=Xhigh, N_X=N_X)
    zen,azi,core = fit_track(e,axis_points,max_vals,1e2)
    fig = figure_3D(axis_points,max_vals,zen,azi,core,res)
    # savefig('') would fail on the default outfile; only save when a path is given.
    if outfile:
        fig.savefig(outfile)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    update_event(e,core,zen,azi)
    ds,Xs,locs,p = shower_axis_slice(e)
    #result
    res.dX.append(Xs)
    res.dl.append(ds)
    res.profile_rit.append(p)
    res.xmax_rit.append(Xs[np.argmax(p)])
    res.azimuth_rec.append(e.azimuth)
    res.zenith_rec.append(e.zenith)
    res.dcore_rec.append(core)
    return res

if __name__ == "__main__":
    # Demo: plot the coherent power in one shower-plane slice of a simulated
    # event at slant depth X = 750.
    sim_file = '../ZH_airshower/mysim.sry'
    ev = REvent(sim_file)
    set_pol_and_bp(ev)
    X = 750

    # +-2 deg grid half-width at the X = 750 distance. The ground height is
    # the core altitude, as in get_axis_points and shower_plane_slice.
    dXref = atm.distance_to_slant_depth(np.deg2rad(ev.zenith), X, ev.core[2])
    scale2d = dXref * np.tan(np.deg2rad(2.))
    xx, yy, p, loc_max = shower_plane_slice(ev, X, 21, 21, scale2d, scale2d)

    slice_figure(ev, X, xx, yy, p, mode='sp')
    plt.show()