Merge branch 'rit-joblib' into main

This commit is contained in:
Eric Teunis de Boone 2023-01-13 18:54:19 +01:00
commit 76b9e99936
2 changed files with 40 additions and 21 deletions

View file

@ -11,6 +11,7 @@ from mpl_toolkits.mplot3d import Axes3D # required for projection='3d' on old ma
import numpy as np import numpy as np
from os import path from os import path
import pickle import pickle
import joblib
from earsim import REvent from earsim import REvent
from atmocal import AtmoCal from atmocal import AtmoCal
@ -73,7 +74,8 @@ if __name__ == "__main__":
ev.antennas[i].t += total_clock_time ev.antennas[i].t += total_clock_time
N_X, Xlow, Xhigh = 23, 100, 1200 N_X, Xlow, Xhigh = 23, 100, 1200
res = rit.reconstruction(ev, outfile=fig_subdir+'/fig.pdf', slice_outdir=fig_subdir+'/', Xlow=Xlow, N_X=N_X, Xhigh=Xhigh) with joblib.parallel_backend("loky"):
res = rit.reconstruction(ev, outfile=fig_subdir+'/fig.pdf', slice_outdir=fig_subdir+'/', Xlow=Xlow, N_X=N_X, Xhigh=Xhigh)
## Save a pickle ## Save a pickle
with open(pickle_fname, 'wb') as fp: with open(pickle_fname, 'wb') as fp:

View file

@ -10,6 +10,12 @@ from scipy.optimize import curve_fit,minimize
import pandas as pd import pandas as pd
import os import os
try:
from joblib import Parallel, delayed
except:
Parallel = None
delayed = lambda x: x
plt.rcParams.update({'font.size': 16}) plt.rcParams.update({'font.size': 16})
atm = AtmoCal() atm = AtmoCal()
@ -79,26 +85,28 @@ def shower_axis_slice(e,Xb=200,Xe=1200,dX=2,zgr=0):
p = np.asanyarray(p) p = np.asanyarray(p)
return ds,Xs,locs,p return ds,Xs,locs,p
def shower_plane_slice(e,X=750.,Nx=10,Ny=10,wx=1e3,wy=1e3,xoff=0,yoff=0,zgr=0): def shower_plane_slice(e,X=750.,Nx=10,Ny=10,wx=1e3,wy=1e3,xoff=0,yoff=0,zgr=0,n_jobs=None):
zgr = zgr + e.core[2] zgr = zgr + e.core[2]
dX = atm.distance_to_slant_depth(np.deg2rad(e.zenith),X,zgr) dX = atm.distance_to_slant_depth(np.deg2rad(e.zenith),X,zgr)
x = np.linspace(-wx,wx,Nx) x = np.linspace(-wx,wx,Nx)
y = np.linspace(-wy,wy,Ny) y = np.linspace(-wy,wy,Ny)
xx = []
yy = [] def loop_func(x_, y_, xoff=xoff, yoff=yoff):
p = []
locs = []
for x_ in x:
for y_ in y:
loc = (x_+xoff)* e.uAxB + (y_+yoff)*e.uAxAxB + dX *e.uA loc = (x_+xoff)* e.uAxB + (y_+yoff)*e.uAxAxB + dX *e.uA
locs.append(loc) locs.append(loc)
P,t_,pulses_,wav,twav = pow_and_time(loc,e) P,t_,pulses_,wav,twav = pow_and_time(loc,e)
xx.append(x_+xoff)
yy.append(y_+yoff) return x_+xoff, y_+yoff, P, locs
p.append(P)
xx = np.asarray(xx) res = ( delayed(loop_func)(x_, y_) for x_ in x for y_ in y)
yy = np.asarray(yy)
p = np.asanyarray(p) if Parallel:
#if n_jobs is None change with `with parallel_backend`
res = Parallel(n_jobs=n_jobs)(res)
# unpack loop results
xx, yy, p, locs = zip(*res)
return xx,yy,p,locs[np.argmax(p)] return xx,yy,p,locs[np.argmax(p)]
def slice_figure(e,X,xx,yy,p,mode='horizontal'): def slice_figure(e,X,xx,yy,p,mode='horizontal'):
@ -146,17 +154,16 @@ def dist_to_line_sum(param,data,weights):
# print('%.2e %.2e %.2e %.2e %.2e'%(x0,y0,theta,phi,dsum)) # print('%.2e %.2e %.2e %.2e %.2e'%(x0,y0,theta,phi,dsum))
return dsum/len(data) return dsum/len(data)
def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15): def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15, n_jobs=None):
axis_points = []
max_vals = []
Xsteps = np.linspace(Xlow, Xhigh, N_X) Xsteps = np.linspace(Xlow, Xhigh, N_X)
zgr=zgr+e.core[2] #not exact zgr=zgr+e.core[2] #not exact
dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr) dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr)
scale2d = dXref*np.tan(np.deg2rad(2.)) scale2d = dXref*np.tan(np.deg2rad(2.))
scale4d = dXref*np.tan(np.deg2rad(4.)) scale4d = dXref*np.tan(np.deg2rad(4.))
scale0_2d=dXref*np.tan(np.deg2rad(0.2)) scale0_2d=dXref*np.tan(np.deg2rad(0.2))
for X in Xsteps:
print(X) def loop_func(X):
print("Starting", X)
x,y,p,loc_max = shower_plane_slice(e,X,21,21,scale2d,scale2d) x,y,p,loc_max = shower_plane_slice(e,X,21,21,scale2d,scale2d)
if savefig: if savefig:
fig,axs = slice_figure(e,X,x,y,p,'sp') fig,axs = slice_figure(e,X,x,y,p,'sp')
@ -175,8 +182,18 @@ def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
fig,axs = slice_figure(e,X,x,y,p,'sp') fig,axs = slice_figure(e,X,x,y,p,'sp')
fig.savefig(path+'X%d_b.pdf'%(X)) fig.savefig(path+'X%d_b.pdf'%(X))
plt.close(fig) plt.close(fig)
max_vals.append(np.max(p))
axis_points.append(loc_max) print("Finished", X)
return np.max(p), loc_max
res = (delayed(loop_func)(X) for X in Xsteps)
if Parallel:
#if n_jobs is None change with `with parallel_backend`
res = Parallel(n_jobs=n_jobs)(res)
# unpack loop results
max_vals, axis_points = zip(*res)
return Xsteps,axis_points,max_vals return Xsteps,axis_points,max_vals