ZH: rit rewrite with unpacking in get_axis_points

This commit is contained in:
Eric Teunis de Boone 2023-01-13 18:21:03 +01:00
parent 0ac1deff34
commit d1700da445

View file

@ -14,6 +14,7 @@ try:
from joblib import Parallel, delayed from joblib import Parallel, delayed
except: except:
Parallel = None Parallel = None
delayed = lambda x: x
plt.rcParams.update({'font.size': 16}) plt.rcParams.update({'font.size': 16})
@ -151,7 +152,7 @@ def dist_to_line_sum(param,data,weights):
# print('%.2e %.2e %.2e %.2e %.2e'%(x0,y0,theta,phi,dsum)) # print('%.2e %.2e %.2e %.2e %.2e'%(x0,y0,theta,phi,dsum))
return dsum/len(data) return dsum/len(data)
def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15): def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15, n_jobs=None):
Xsteps = np.linspace(Xlow, Xhigh, N_X) Xsteps = np.linspace(Xlow, Xhigh, N_X)
zgr=zgr+e.core[2] #not exact zgr=zgr+e.core[2] #not exact
dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr) dXref = atm.distance_to_slant_depth(np.deg2rad(e.zenith),750,zgr)
@ -183,18 +184,14 @@ def get_axis_points(e,savefig=True,path="",zgr=0,Xlow=300, Xhigh=1000, N_X=15):
print("Finished", X) print("Finished", X)
return np.max(p), loc_max return np.max(p), loc_max
res = (delayed(loop_func)(X) for X in Xsteps)
if Parallel: if Parallel:
n_jobs= None # if None change with `with parallel_backend` #if n_jobs is None change with `with parallel_backend`
res = Parallel(n_jobs=n_jobs)(delayed(loop_func)(X) for X in Xsteps) res = Parallel(n_jobs=n_jobs)(res)
max_vals = [ item[0] for item in res ] # unpack loop results
axis_points = [ item[1] for item in res ] max_vals, axis_points = zip(*res)
else:
for X in Xsteps:
_max_val, _loc_max = loop_func(X)
max_vals.append(_max_val)
axis_points.append(_loc_max)
return Xsteps,axis_points,max_vals return Xsteps,axis_points,max_vals