# Source: m-thesis-introduction/airshower_beacon_simulation/lib/figlib.py
# (web-viewer listing at time of capture: 286 lines, 8.8 KiB, Python)

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy import special
from scipy import optimize
from itertools import zip_longest
def expectation(x, pdfx):
    """
    Approximate the expectation value of a sampled probability density.

    The integral of ``x * pdf(x)`` is evaluated as a Riemann sum whose
    step is taken from the first two samples only, so $x$ is expected to
    be uniformly spaced and to hold at least two points.

    Parameters:
        x:    array-like of sample positions.
        pdfx: array-like of density values, one per entry of $x$.

    Returns:
        ``sum(x * pdfx * dx)`` with ``dx = x[1] - x[0]``.
    """
    # Coerce first: plain lists/tuples would otherwise fail on `x * pdfx`.
    x = np.asarray(x)
    pdfx = np.asarray(pdfx)

    dx = x[1] - x[0]
    return np.sum(x * pdfx * dx)
def variance(x, pdfx):
    """
    Approximate the variance of a sampled probability density.

    Computed as ``E[x^2] - E[x]^2`` where both moments are Riemann sums
    with step ``dx = x[1] - x[0]``; $x$ is therefore expected to be
    uniformly spaced and to hold at least two points.

    Parameters:
        x:    array-like of sample positions.
        pdfx: array-like of density values, one per entry of $x$.

    Returns:
        ``sum(x**2 * pdfx * dx) - sum(x * pdfx * dx)**2``.
    """
    # Coerce first: plain lists/tuples would otherwise fail on `x**2`.
    x = np.asarray(x)
    pdfx = np.asarray(pdfx)

    dx = x[1] - x[0]
    # First moment, evaluated with the same Riemann sum as the second one.
    mu = np.sum(x * pdfx * dx)
    return np.sum(x**2 * pdfx * dx) - mu**2
def random_phase_sum_distribution(theta, sigma, s=1):
    """
    Evaluate the random-phase-sum density at angle(s) $theta$.

    Only the ratio ``k = s/sigma`` enters the expression, which is the
    sum of a theta-independent term ``exp(-k^2/2) / (2 pi)`` and a term
    ``k cos(theta) / sqrt(2 pi) * exp(-(k sin(theta))^2 / 2)
      * (1 + erf(k cos(theta) / sqrt(2))) / 2``.

    Parameters:
        theta: scalar or array-like of angles in radians.
        sigma: width parameter (denominator of k).
        s:     amplitude parameter (numerator of k), default 1.

    Returns:
        Density value(s) with the shape of $theta$.
    """
    angle = np.asarray(theta)
    ratio = s/sigma
    cos_t, sin_t = np.cos(angle), np.sin(angle)
    two_pi = 2*np.pi

    # theta-independent floor
    flat_term = np.exp(-ratio**2/2)/two_pi
    # Gaussian-like factor in sin(theta)
    gauss_factor = (two_pi**-0.5)*ratio*np.exp(-(ratio*sin_t)**2/2)
    # error-function factor weighted by cos(theta)
    erf_factor = (1.+special.erf(ratio*cos_t*2**-0.5))*cos_t/2

    return flat_term + gauss_factor * erf_factor
def gaussian_phase_distribution(theta, sigma, s=1):
    """
    Evaluate a zero-mean Gaussian density at angle(s) $theta$.

    With ``k = s/sigma`` the value is
    ``k / sqrt(2 pi) * exp(-(k theta)^2 / 2)``, i.e. a normal
    distribution with standard deviation ``sigma/s``.

    Parameters:
        theta: scalar or array-like of angles.
        sigma: width parameter (denominator of k).
        s:     amplitude parameter (numerator of k), default 1.

    Returns:
        Density value(s) with the shape of $theta$.
    """
    angle = np.asarray(theta)
    ratio = s/sigma

    prefactor = (2*np.pi)**-0.5*ratio
    return prefactor*np.exp(-(ratio*angle)**2/2)
def phase_comparison_figure(
        measured_phases,
        true_phases,
        plot_residuals=True,
        f_beacon=None,
        hist_kwargs={},
        sc_kwargs={},
        text_kwargs={},
        colors=['blue', 'orange'],
        legend_on_scatter=True,
        secondary_axis='time',
        fit_gaussian=False,
        fit_randomphasesum=False,
        mean_snr=None,
        return_fit_info=False,
        **fig_kwargs
    ):
    """
    Create a figure comparing measured_phase against true_phase
    by both plotting the values, and the residuals.

    The figure has up to two stacked panels: a histogram of
    $measured_phases$ (drawn through fitted_histogram_figure) and a
    scatter plot of phase against antenna index.

    Parameters:
        measured_phases:    sequence of measured phases (or residuals).
        true_phases:        sequence of true phases; only drawn when
                            $plot_residuals$ is False.
        plot_residuals:     if True only $measured_phases$ is drawn.
        f_beacon:           if truthy, add a secondary x-axis on top of
                            the first panel converting phase <-> time via
                            2*pi*f_beacon. The time label says [ns], so
                            f_beacon is presumably in GHz - confirm with
                            the caller.
        hist_kwargs:        extra kwargs for the histogram, or False to
                            drop the histogram panel.
        sc_kwargs:          extra kwargs for the scatter plot, or False to
                            drop the scatter panel.
        text_kwargs:        extra kwargs for the text boxes in the histogram.
        colors:             (measured, actual) colours.
        legend_on_scatter:  draw a legend on the scatter panel when both
                            data sets are shown.
        secondary_axis:     'time' or 'phase'; selects direction and label
                            of the secondary axis.
        fit_gaussian:       fit a 'gaussian' distribution to the histogram.
        fit_randomphasesum: fit a 'randomphasesum' distribution.
        mean_snr:           forwarded to fitted_histogram_figure.
        return_fit_info:    if True return ``(fig, fit_info)``.
        **fig_kwargs:       forwarded to plt.subplots (sharex=True by default).

    Returns:
        The figure, or ``(fig, fit_info)`` when $return_fit_info$ is set.
        $fit_info$ is the dict returned by fitted_histogram_figure, or an
        empty dict when no histogram is drawn.

    Raises:
        ValueError: when both $hist_kwargs$ and $sc_kwargs$ are False.
    """
    # NOTE: the mutable defaults in the signature are never mutated;
    # they are always merged into fresh dicts below.
    default_fig_kwargs = dict(sharex=True)
    default_hist_kwargs = dict(bins='sqrt', density=False, alpha=0.8, histtype='step')
    default_text_kwargs = dict(fontsize=14, verticalalignment='top')
    default_sc_kwargs = dict(alpha=0.6, ls='none')

    do_hist_plot = hist_kwargs is not False
    if hist_kwargs is False:
        hist_kwargs = {}

    do_scatter_plot = sc_kwargs is not False
    if sc_kwargs is False:
        sc_kwargs = {}

    if not (do_hist_plot or do_scatter_plot):
        # plt.subplots(0, 1) would fail with a far less helpful message
        raise ValueError('At least one of hist_kwargs and sc_kwargs must not be False')

    fig_kwargs = {**default_fig_kwargs, **fig_kwargs}
    hist_kwargs = {**default_hist_kwargs, **hist_kwargs}
    text_kwargs = {**default_text_kwargs, **text_kwargs}
    sc_kwargs = {**default_sc_kwargs, **sc_kwargs}

    fig, axs = plt.subplots(0+do_hist_plot+do_scatter_plot, 1, **fig_kwargs)

    # a single panel is returned as a bare Axes
    if not hasattr(axs, '__len__'):
        axs = [axs]

    # The histogram (if any) is always the first panel and the scatter
    # plot (if any) the last one; with a single panel these coincide.
    hist_ax = axs[0] if do_hist_plot else None
    scatter_ax = axs[-1] if do_scatter_plot else None

    if f_beacon and secondary_axis in ['phase', 'time']:
        def phase2time(x):
            return x/(2*np.pi*f_beacon)

        def time2phase(x):
            return 2*np.pi*x*f_beacon

        if secondary_axis == 'time':
            functions = (phase2time, time2phase)
            label = 'Time $\\varphi/(2\\pi f_{beac})$ [ns]'
        else:
            functions = (time2phase, phase2time)
            label = 'Phase $2\\pi t f_{beac}$ [rad]'

        secax = axs[0].secondary_xaxis('top', functions=functions)
        # the label was prepared above but previously never applied
        secax.set_xlabel(label)

    # Histogram
    fit_info = {}
    if do_hist_plot:
        hist_ax.set_ylabel("#")

        fit_distr = []
        if fit_gaussian:
            fit_distr.append('gaussian')
        if fit_randomphasesum:
            fit_distr.append('randomphasesum')

        _, fit_info = fitted_histogram_figure(
                measured_phases,
                ax=hist_ax,
                text_kwargs=text_kwargs,
                hist_kwargs={**hist_kwargs, **dict(label='Measured', color=colors[0], ls='solid')},
                mean_snr=mean_snr,
                fit_distr=fit_distr,
                )

        if not plot_residuals: # also plot the true clock phases
            # Merge the fixed styling into the kwargs dict (same as for
            # 'Measured') so user-supplied color/label/ls cannot clash
            # with explicit keywords. Reuse the bins of the measured data.
            actual_kwargs = {
                **hist_kwargs,
                **dict(bins=fit_info['bins'], color=colors[1], label='Actual', ls='dashed'),
            }
            hist_ax.hist(true_phases, **actual_kwargs)

    # Scatter plot
    if do_scatter_plot:
        scatter_ax.set_ylabel("Antenna no.")
        scatter_ax.plot(measured_phases, np.arange(len(measured_phases)), marker='x' if plot_residuals else '3', color=colors[0], label='Measured', **sc_kwargs)

        if not plot_residuals: # also plot the true clock phases
            scatter_ax.plot(true_phases, np.arange(len(true_phases)), marker='4', color=colors[1], label='Actual', **sc_kwargs)

        # a legend is only useful when two data sets are shown
        if not plot_residuals and legend_on_scatter:
            scatter_ax.legend()

    fig.tight_layout()

    if return_fit_info:
        return fig, fit_info

    return fig
def _fit_scaled_pdf(pdf, scale, bin_centers, counts, xs):
    """
    Least-squares fit of ``scale * pdf(x, sigma, s)`` to histogram counts.

    Returns the fitted ``(sigma, s)`` and the scaled curve evaluated at $xs$.
    """
    # Explicit parameter names: curve_fit inspects the signature to
    # determine the number of fit parameters.
    def scaled_pdf(x, sigma, s):
        return scale * pdf(x, sigma, s)

    # sigma in [0, inf), s pinned to (almost) 1: only s/sigma enters the pdfs
    bounds = ((0, 0.9999), (np.inf, 1))
    fit_params, _pcov = optimize.curve_fit(scaled_pdf, bin_centers, counts, bounds=bounds)

    return fit_params, scaled_pdf(xs, *fit_params)


def fitted_histogram_figure(
        amplitudes,
        fit_distr = None,
        calc_chisq = True,
        text_kwargs={},
        hist_kwargs={},
        mean_snr = None,
        ax = None,
        **fig_kwargs
    ):
    """
    Create a figure showing $amplitudes$ as a histogram.

    If fit_distr is a (list of) string, also fit the respective
    distribution function and show the parameters on the figure.

    Parameters:
        amplitudes:  sequence of values to histogram.
        fit_distr:   None, a name or a list of names out of 'rice',
                     'gaussian', 'rayleigh', 'randomphasesum', 'gaussphase'.
        calc_chisq:  also compute chi-square for distributions that have a
                     cdf (the scipy.stats ones).
        text_kwargs: extra kwargs for the text boxes.
        hist_kwargs: extra kwargs for ax.hist.
        mean_snr:    if truthy, shown in the top-right corner.
        ax:          Axes to draw on; a new figure is created when None.
        **fig_kwargs: forwarded to plt.subplots when $ax$ is None.

    Returns:
        ``(fig, info)`` with ``info = dict(fit_info=..., counts=..., bins=...)``.
        $fit_info$ holds one dict per fitted distribution with the keys
        'name', 'param_names', 'param_values', 'text_str' and, when
        chi-square was computed, 'chisq' and 'dof'.

    Raises:
        ValueError: for an unknown distribution name.
    """
    # NOTE: the mutable defaults in the signature are never mutated;
    # they are merged into fresh dicts below.
    default_hist_kwargs = dict(bins='sqrt', density=False, alpha=0.8, histtype='step', label='hist')
    default_text_kwargs = dict(fontsize=14, verticalalignment='top')

    if isinstance(fit_distr, str):
        fit_distr = [fit_distr]

    hist_kwargs = {**default_hist_kwargs, **hist_kwargs}
    text_kwargs = {**default_text_kwargs, **text_kwargs}

    if ax is None:
        fig, ax = plt.subplots(1, 1, **fig_kwargs)
    else:
        fig = ax.get_figure()

    # text positions below are given in axes coordinates
    text_kwargs['transform'] = ax.transAxes

    counts, bins, _patches = ax.hist(amplitudes, **hist_kwargs)

    fit_info = []
    if fit_distr:
        min_x = min(amplitudes)
        max_x = max(amplitudes)

        bin_widths = np.diff(bins)
        bin_centers = bins[:-1] + bin_widths / 2

        # Area under the histogram: N*dx for plain counts, 1 for
        # density=True. Multiplying a normalised pdf by it puts the
        # curve on the same vertical scale as the histogram.
        scale = np.sum(counts * bin_widths)

        xs = np.linspace(min_x, max_x)

        for distr in fit_distr:
            # maps the fitted parameters onto the ones shown as text
            fit_params2text_params = lambda x: x
            fit_ys = None
            fit_params = None
            cdf = None

            if 'rice' == distr:
                name = "Rice"
                param_names = [ "$\\nu$", "$\\sigma$" ]
                distr_func = stats.rice
                # scipy fits (b, loc, scale): sigma = scale, nu = b*scale
                fit_params2text_params = lambda x: (x[0]*x[-1], x[-1])

            elif 'gaussian' == distr:
                name = "Norm"
                param_names = [ "$\\mu$", "$\\sigma$" ]
                distr_func = stats.norm

            elif 'rayleigh' == distr:
                name = "Rayleigh"
                param_names = [ "$\\sigma$" ]
                distr_func = stats.rayleigh
                # NOTE(review): scipy fits (loc, scale); loc + scale/2 is
                # kept as-is, but plain `scale` may be what is meant - confirm.
                fit_params2text_params = lambda x: (x[0]+x[1]/2,)

            elif 'randomphasesum' == distr:
                name = "RandPhaseS"
                # names follow the (s, sigma) order of the text values below
                param_names = [ 's', "$\\sigma$" ]
                fit_params, fit_ys = _fit_scaled_pdf(
                        random_phase_sum_distribution, scale, bin_centers, counts, xs)
                fit_params2text_params = lambda x: (x[1], x[0])

            elif 'gaussphase' == distr:
                name = 'GaussPhase'
                # names follow the (s, sigma) order of the text values below
                param_names = [ 's', "$\\sigma$" ]
                fit_params, fit_ys = _fit_scaled_pdf(
                        gaussian_phase_distribution, scale, bin_centers, counts, xs)
                fit_params2text_params = lambda x: (x[1], x[0])

            else:
                raise ValueError('Unknown distribution function '+distr)

            label = name +"(" + ','.join(param_names) + ')'

            # scipy.stats distributions: maximum-likelihood fit on the raw data
            if fit_ys is None:
                fit_params = distr_func.fit(amplitudes)
                fit_ys = scale * distr_func.pdf(xs, *fit_params)
                cdf = distr_func.cdf

            ax.plot(xs, fit_ys, label=label)

            chisq_strs = []
            if calc_chisq and cdf:
                # Expected counts per bin. The fitted distribution has some
                # mass outside [bins[0], bins[-1]], so rescale to the
                # observed total; scipy.stats.chisquare requires both
                # totals to agree.
                expected = np.diff(cdf(bins, *fit_params))
                expected = expected * (np.sum(counts) / np.sum(expected))

                c2t = stats.chisquare(counts, expected, ddof=len(fit_params))
                # same degrees of freedom chisquare uses for its p-value
                dof = len(counts) - 1 - len(fit_params)

                chisq_strs = [
                    f"$\\chi^2$/dof = {c2t[0]: .2g}/{dof}"
                ]

            # change parameters if needed
            text_fit_params = fit_params2text_params(fit_params)

            text_str = "\n".join(
                    [label]
                    +
                    [ f"{param} = {value: .2e}" for param, value in zip_longest(param_names, text_fit_params, fillvalue='?') ]
                    +
                    chisq_strs
            )

            this_info = {
                'name': name,
                'param_names': param_names,
                'param_values': text_fit_params,
                'text_str': text_str,
            }

            if chisq_strs:
                this_info['chisq'] = c2t[0]
                this_info['dof'] = dof

            fit_info.append(this_info)

        # one text box in the top-left corner listing every fit
        loc = (0.02, 0.95)
        ax.text(*loc, "\n\n".join([info['text_str'] for info in fit_info]), **{**text_kwargs, **dict(ha='left')})

    if mean_snr:
        text_str = f"$\\langle SNR \\rangle$ = {mean_snr: .1e}"
        loc = (0.98, 0.95)
        ax.text(*loc, text_str, **{**text_kwargs, **dict(ha='right')})

    return fig, dict(fit_info=fit_info, counts=counts, bins=bins)