ZH: rit rewrite with unpacking in get_axis_points

This commit is contained in:
Eric Teunis de Boone 2023-01-13 18:21:03 +01:00
parent 0ac1deff34
commit d1700da445

View file

@ -14,6 +14,7 @@ try:
from joblib import Parallel, delayed
except:
Parallel = None
delayed = lambda x: x
plt.rcParams.update({'font.size': 16})
@ -151,7 +152,7 @@ def dist_to_line_sum(param,data,weights):
# print('%.2e %.2e %.2e %.2e %.2e'%(x0,y0,theta,phi,dsum))
return dsum/len(data)
def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15, n_jobs=None):
Xsteps = np.linspace(Xlow, Xhigh, N_X)
zgr=zgr+e.core[2] #not exact
dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr)
@ -183,18 +184,14 @@ def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
print("Finished", X)
return np.max(p), loc_max
res = (delayed(loop_func)(X) for X in Xsteps)
if Parallel:
n_jobs= None # if None change with `with parallel_backend`
res = Parallel(n_jobs=n_jobs)(delayed(loop_func)(X) for X in Xsteps)
#if n_jobs is None change with `with parallel_backend`
res = Parallel(n_jobs=n_jobs)(res)
max_vals = [ item[0] for item in res ]
axis_points = [ item[1] for item in res ]
else:
for X in Xsteps:
_max_val, _loc_max = loop_func(X)
max_vals.append(_max_val)
axis_points.append(_loc_max)
# unpack loop results
max_vals, axis_points = zip(*res)
return Xsteps,axis_points,max_vals