ZH: rit attempt multiprocessing in get_axis_points

This commit is contained in:
Eric Teunis de Boone 2023-01-11 19:43:48 +01:00
parent c29bb2ee50
commit 3c99ac7118

View file

@ -10,6 +10,11 @@ from scipy.optimize import curve_fit,minimize
import pandas as pd import pandas as pd
import os import os
try:
from joblib import Parallel, delayed
except:
Parallel = None
plt.rcParams.update({'font.size': 16}) plt.rcParams.update({'font.size': 16})
atm = AtmoCal() atm = AtmoCal()
@ -147,16 +152,15 @@ def dist_to_line_sum(param,data,weights):
return dsum/len(data) return dsum/len(data)
def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15): def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
axis_points = []
max_vals = []
Xsteps = np.linspace(Xlow, Xhigh, N_X) Xsteps = np.linspace(Xlow, Xhigh, N_X)
zgr=zgr+e.core[2] #not exact zgr=zgr+e.core[2] #not exact
dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr) dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr)
scale2d = dXref*np.tan(np.deg2rad(2.)) scale2d = dXref*np.tan(np.deg2rad(2.))
scale4d = dXref*np.tan(np.deg2rad(4.)) scale4d = dXref*np.tan(np.deg2rad(4.))
scale0_2d=dXref*np.tan(np.deg2rad(0.2)) scale0_2d=dXref*np.tan(np.deg2rad(0.2))
for X in Xsteps:
print(X) def loop_func(X):
print("Starting", X)
x,y,p,loc_max = shower_plane_slice(e,X,21,21,scale2d,scale2d) x,y,p,loc_max = shower_plane_slice(e,X,21,21,scale2d,scale2d)
if savefig: if savefig:
fig,axs = slice_figure(e,X,x,y,p,'sp') fig,axs = slice_figure(e,X,x,y,p,'sp')
@ -175,8 +179,22 @@ def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
fig,axs = slice_figure(e,X,x,y,p,'sp') fig,axs = slice_figure(e,X,x,y,p,'sp')
fig.savefig(path+'X%d_b.pdf'%(X)) fig.savefig(path+'X%d_b.pdf'%(X))
plt.close(fig) plt.close(fig)
max_vals.append(np.max(p))
axis_points.append(loc_max) print("Finished", X)
return np.max(p), loc_max
if Parallel:
n_jobs= None # if None change with `with parallel_backend`
res = Parallel(n_jobs=n_jobs)(delayed(loop_func)(X) for X in Xsteps)
max_vals = [ item[0] for item in res ]
axis_points = [ item[1] for item in res ]
else:
for X in Xsteps:
_max_val, _loc_max = loop_func(X)
max_vals.append(_max_val)
axis_points.append(_loc_max)
return Xsteps,axis_points,max_vals return Xsteps,axis_points,max_vals