diff --git a/fourier/mylib/plotting.py b/fourier/mylib/plotting.py index 8960b21..d29a32b 100644 --- a/fourier/mylib/plotting.py +++ b/fourier/mylib/plotting.py @@ -1,11 +1,17 @@ """ -Functions to simplify plotting +Functions to simplify plotting of fourier spectra """ import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import numpy as np -def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_amplitude=None, ax=None, freq_unit="Hz", freq_scaler=1): +def plot_spectrum( + spectrum, freqs, + plot_complex=False, plot_power=False, plot_amplitude=None, + freq_unit="Hz", freq_scaler=1, + title='Spectrum', xlabel='Frequency', ax=None, + **plot_kwargs + ): """ Plot a signal's spectrum on an Axis object""" plot_amplitude = plot_amplitude or (not plot_power and not plot_complex) alpha = 1 @@ -13,8 +19,8 @@ def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_a if ax is None: ax = plt.gca() - ax.set_title("Spectrum") - ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) + ax.set_title(title) + ax.set_xlabel(xlabel + (" ["+freq_unit+"]" if freq_unit else "" )) ylabel = "" if plot_amplitude or plot_complex: ylabel = "Amplitude" @@ -26,70 +32,156 @@ def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_a if plot_complex: alpha = 0.5 - ax.plot(freqs/freq_scaler, np.real(spectrum), '.-', label='Real', alpha=alpha) - ax.plot(freqs/freq_scaler, np.imag(spectrum), '.-', label='Imag', alpha=alpha) + ax.plot(freqs/freq_scaler, np.real(spectrum), '.-', label='Real', alpha=alpha, **plot_kwargs) + ax.plot(freqs/freq_scaler, np.imag(spectrum), '.-', label='Imag', alpha=alpha, **plot_kwargs) if plot_power: - ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha) + ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha, **plot_kwargs) if plot_amplitude: - ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha) + ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha, **plot_kwargs) ax.legend() return ax -def plot_phase( spectrum, freqs, ylim_epsilon=0.5, ax=None, freq_unit="Hz", freq_scaler=1): +def plot_phase( + spectrum, freqs, + ylim_epsilon=0.5, ax=None, grid=True, + freq_unit="Hz", freq_scaler=1, xlabel='Frequency', + **plot_kwargs + ): + """ + Plot the phase of spectrum + """ + if ax is None: ax = plt.gca() ax.set_ylabel("Phase") - ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) + ax.set_xlabel(xlabel + (" ["+freq_unit+"]" if freq_unit else "" )) + ax.grid(grid) - ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-') + ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-', **plot_kwargs) ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon) + major_tick = Multiple(2, np.pi, '\pi') + minor_tick = Multiple(12, np.pi, '\pi') + ax.yaxis.set_major_locator(major_tick.locator()) + ax.yaxis.set_major_formatter(major_tick.formatter()) + ax.yaxis.set_minor_locator(minor_tick.locator()) + return ax -def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **kwargs): +def plot_signal( + signal, sample_rate = 1, + time=None, ax=None, + title='Signal', ylabel='Amplitude', + time_unit="s", xlabel='Time', + **kwargs): + """ + Plot the signal in the time domain + """ + if ax is None: ax = plt.gca() if time is None: time = np.arange(len(signal))/sample_rate - ax.set_title("Signal") - ax.set_xlabel("t" + (" ["+time_unit+"]" if time_unit else "" )) - ax.set_ylabel("A(t)") + ax.set_title(title) + ax.set_xlabel(xlabel + (" ["+time_unit+"]" if time_unit else "" )) + ax.set_ylabel(ylabel) ax.plot(time, signal, **kwargs) return ax -def plot_combined_spectrum(spectrum, freqs, - spectrum_kwargs={}, fig=None, gs=None, freq_scaler=1, freq_unit="Hz"): - """Plot both the frequencies and phase in one figure.""" +def plot_combined_spectrum( + spectrum, freqs, + fig=None, gs=None, + spectrum_ax=None, phase_ax=None, + spectrum_kwargs={}, phase_kwargs={}, + **shared_kwargs + ): + """ + Plot both the frequencies and phase in one figure. + """ - # configure plotting layout - if fig is None: - fig = plt.figure(figsize=(8, 16)) + ax1, ax2 = None, None + if spectrum_ax is not None or phase_ax is not None: + if spectrum_ax is not None: + ax1 = spectrum_ax + if phase_ax is not None: + ax2 = phase_ax - if gs is None: - gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3,1], hspace=0) + else: + # configure plotting layout + if fig is None: + fig = plt.figure(figsize=(8, 16)) - ax1 = fig.add_subplot(gs[:-1, -1]) - ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1) + if gs is None: + gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3,1], hspace=0) + + ax1 = fig.add_subplot(gs[:-1, -1]) + ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1) axes = np.array([ax1, ax2]) # plot the spectrum - plot_spectrum(spectrum, freqs, ax=ax1, freq_scaler=freq_scaler, freq_unit=freq_unit, **spectrum_kwargs) + plot_spectrum(spectrum, freqs, ax=ax1, **{ **shared_kwargs, **spectrum_kwargs}) # plot the phase - plot_phase(spectrum, freqs, ax=ax2, freq_scaler=freq_scaler, freq_unit=freq_unit) + plot_phase(spectrum, freqs, ax=ax2, **{ **shared_kwargs, **phase_kwargs}) ax1.xaxis.tick_top() [label.set_visible(False) for label in ax1.get_xticklabels()] return fig, axes +def multiple_formatter(denominator=2, number=np.pi, latex='\pi'): + """ + From https://stackoverflow.com/a/53586826 + """ + def gcd(a, b): + while b: + a, b = b, a%b + return a + + def _multiple_formatter(x, pos): + den = denominator + num = np.int(np.rint(den*x/number)) + com = gcd(num,den) + (num,den) = (int(num/com),int(den/com)) + if den==1: + if num==0: + return r'$0$' + if num==1: + return r'$%s$'%latex + elif num==-1: + return r'$-%s$'%latex + else: + return r'$%s%s$'%(num,latex) + else: + if num==1: + return r'$\frac{%s}{%s}$'%(latex,den) + elif num==-1: + return r'$\frac{-%s}{%s}$'%(latex,den) + else: + return r'$\frac{%s%s}{%s}$'%(num,latex,den) + return _multiple_formatter + +class Multiple: + """ + From https://stackoverflow.com/a/53586826 + """ + def __init__(self, denominator=2, number=np.pi, latex='\pi'): + self.denominator = denominator + self.number = number + self.latex = latex + + def locator(self): + return plt.MultipleLocator(self.number / self.denominator) + + def formatter(self): + return plt.FuncFormatter(multiple_formatter(self.denominator, self.number, self.latex))