"""Functions to simplify plotting of Fourier spectra."""
# vim: fdm=indent ts=4

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

from .fft import ft_spectrum


def _label_with_unit(label, unit):
    """Return ``label`` followed by `` [unit]``, or ``label`` alone if ``unit`` is empty."""
    return label + (" [" + unit + "]" if unit else "")


def plot_spectrum(
    spectrum, freqs,
    plot_complex=False, plot_power=False, plot_amplitude=None,
    freq_unit="Hz", freq_scaler=1,
    title='Spectrum', xlabel='Frequency',
    ax=None, **plot_kwargs
):
    """Plot a signal's spectrum on an Axes object.

    Parameters
    ----------
    spectrum : array_like
        Spectrum values (possibly complex), one per entry of ``freqs``.
    freqs : array_like
        Frequencies belonging to ``spectrum``.
    plot_complex : bool
        Plot the real and imaginary parts as two separate curves.
    plot_power : bool
        Plot ``abs(spectrum)**2``.
    plot_amplitude : bool or None
        Plot ``abs(spectrum)``. If falsy, the amplitude is still plotted
        when neither ``plot_power`` nor ``plot_complex`` is set, so that
        something is always drawn.
    freq_unit : str
        Unit appended to the x label as `` [unit]``; omitted when empty.
    freq_scaler : float
        ``freqs`` are divided by this value before plotting.
    title, xlabel : str
        Axes title and x-axis label.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    **plot_kwargs
        Forwarded to every ``ax.plot`` call. An ``alpha`` given here
        overrides the default (0.5 when ``plot_complex`` is set, else 1).

    Returns
    -------
    matplotlib.axes.Axes
        The Axes that was drawn on.
    """
    plot_amplitude = plot_amplitude or (not plot_power and not plot_complex)
    # Pop alpha so a caller-supplied value does not collide with the explicit
    # ``alpha=`` keyword below (that used to raise a duplicate-keyword TypeError).
    alpha = plot_kwargs.pop('alpha', 0.5 if plot_complex else 1)

    if ax is None:
        ax = plt.gca()

    ax.set_title(title)
    ax.set_xlabel(_label_with_unit(xlabel, freq_unit))

    ylabel_parts = []
    if plot_amplitude or plot_complex:
        ylabel_parts.append("Amplitude")
    if plot_power:
        ylabel_parts.append("Power")
    ax.set_ylabel("|".join(ylabel_parts))

    x = np.asarray(freqs) / freq_scaler

    if plot_complex:
        ax.plot(x, np.real(spectrum), '.-', label='Real', alpha=alpha, **plot_kwargs)
        ax.plot(x, np.imag(spectrum), '.-', label='Imag', alpha=alpha, **plot_kwargs)

    if plot_power:
        ax.plot(x, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha, **plot_kwargs)

    if plot_amplitude:
        ax.plot(x, np.abs(spectrum), '.-', label='Abs', alpha=alpha, **plot_kwargs)

    ax.legend()

    return ax


def plot_phase(
    spectrum, freqs,
    ylim_epsilon=0.5, ax=None, grid=True,
    freq_unit="Hz", freq_scaler=1, xlabel='Frequency',
    major_divider=2, minor_divider=12,
    **plot_kwargs
):
    """Plot the phase (``np.angle``) of a spectrum on an Axes object.

    The y axis is limited to ``[-pi - ylim_epsilon, pi + ylim_epsilon]``.
    Major ticks are placed every ``pi / major_divider`` and labelled as
    multiples of pi. Minor ticks are placed every ``pi / minor_divider``.

    Parameters
    ----------
    spectrum : array_like
        Spectrum values (possibly complex).
    freqs : array_like
        Frequencies belonging to ``spectrum``.
    ylim_epsilon : float
        Margin added beyond +/- pi on the y axis.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    grid : bool
        Passed to ``ax.grid``.
    freq_unit, freq_scaler, xlabel
        As in :func:`plot_spectrum`.
    major_divider, minor_divider : int
        Tick spacing on the y axis is pi divided by these values.
    **plot_kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes that was drawn on.
    """
    if ax is None:
        ax = plt.gca()

    ax.set_ylabel("Phase")
    ax.set_xlabel(_label_with_unit(xlabel, freq_unit))
    ax.grid(grid)

    ax.plot(np.asarray(freqs) / freq_scaler, np.angle(spectrum), '.-', **plot_kwargs)

    ax.set_ylim(-np.pi - ylim_epsilon, np.pi + ylim_epsilon)
    axis_pi_ticker(ax.yaxis, major_divider=major_divider, minor_divider=minor_divider)

    return ax


def axis_pi_ticker(axis, major_divider=2, minor_divider=12):
    """Put ticks at multiples of pi on ``axis``.

    Major ticks are set every ``pi / major_divider`` and labelled with LaTeX
    fractions of pi. Minor ticks are set every ``pi / minor_divider``; no
    minor formatter is installed. Returns ``axis``.
    """
    major_tick = Multiple(major_divider, np.pi, r'\pi')
    minor_tick = Multiple(minor_divider, np.pi, r'\pi')

    axis.set_major_locator(major_tick.locator())
    axis.set_major_formatter(major_tick.formatter())
    axis.set_minor_locator(minor_tick.locator())

    return axis


def plot_signal(
    signal, sample_rate=1, time=None, ax=None,
    title='Signal', ylabel='Amplitude',
    time_unit="s", xlabel='Time',
    **kwargs
):
    """Plot a signal in the time domain.

    If ``time`` is not given, it defaults to
    ``np.arange(len(signal)) / sample_rate``. Extra ``kwargs`` are forwarded
    to ``ax.plot``. Returns the Axes that was drawn on (``plt.gca()`` when
    ``ax`` is None).
    """
    if ax is None:
        ax = plt.gca()

    if time is None:
        time = np.arange(len(signal)) / sample_rate

    ax.set_title(title)
    ax.set_xlabel(_label_with_unit(xlabel, time_unit))
    ax.set_ylabel(ylabel)
    ax.plot(time, signal, **kwargs)

    return ax


def plot_combined_spectrum(
    spectrum, freqs,
    fig=None, gs=None,
    spectrum_ax=None, phase_ax=None,
    spectrum_kwargs=None, phase_kwargs=None,
    **shared_kwargs
):
    """Plot a spectrum and its phase, stacked vertically with a shared x axis.

    If ``spectrum_ax`` and/or ``phase_ax`` are given, only the supplied
    axes are drawn on. A missing one is left as ``None`` rather than
    falling back to ``plt.gca()``. Otherwise two axes are created in
    ``gs``: a 2x1 GridSpec with height ratios 3:1 and no vertical gap by
    default, inside ``fig`` (a new 8x16 figure by default).

    ``shared_kwargs`` go to both :func:`plot_spectrum` and
    :func:`plot_phase`. ``spectrum_kwargs`` / ``phase_kwargs`` are merged
    over them for the respective call.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The given figure, the figure of the supplied axes, or the new figure.
    axes : numpy.ndarray
        ``[spectrum_axes, phase_axes]`` (object array; entries may be None
        when the caller supplied only one of the axes).
    """
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    phase_kwargs = {} if phase_kwargs is None else phase_kwargs

    if spectrum_ax is not None or phase_ax is not None:
        # The caller controls the layout; draw only on what was provided.
        ax1, ax2 = spectrum_ax, phase_ax
        if fig is None:
            fig = (ax1 if ax1 is not None else ax2).figure
    else:
        # Configure plotting layout.
        if fig is None:
            fig = plt.figure(figsize=(8, 16))
        if gs is None:
            gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3, 1], hspace=0)

        ax1 = fig.add_subplot(gs[:-1, -1])
        ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1)

    if ax1 is not None:
        plot_spectrum(spectrum, freqs, ax=ax1, **{**shared_kwargs, **spectrum_kwargs})

    if ax2 is not None:
        plot_phase(spectrum, freqs, ax=ax2, **{**shared_kwargs, **phase_kwargs})

    if ax1 is not None and ax2 is not None:
        # Stacked layout: spectrum ticks on top, without labels, so they do
        # not collide with the phase plot directly below.
        ax1.xaxis.tick_top()
        ax1.tick_params(axis='x', labeltop=False, labelbottom=False)

    return fig, np.array([ax1, ax2])


def plot_signal_and_spectrum(
    signal, sample_rate=1, title=None,
    signal_kwargs=None, ft_kwargs=None,
    spectrum_kwargs=None, phase_kwargs=None,
    **phase_spectrum_kwargs
):
    """Create a figure showing both the signal and its combined spectrum.

    The figure (16x4) has the time-domain signal on the left and, on the
    right, the spectrum stacked above its phase (height ratio 3:1). The
    spectrum is obtained as
    ``spectrum, freqs = ft_spectrum(signal, sample_rate, **ft_kwargs)``.

    ``phase_kwargs`` defaults to ``{'major_divider': 1, 'minor_divider': 6}``.
    Remaining keyword arguments are passed as shared kwargs to
    :func:`plot_combined_spectrum`.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : numpy.ndarray
        ``[signal_axes, spectrum_axes, phase_axes]``.
    """
    signal_kwargs = {} if signal_kwargs is None else signal_kwargs
    ft_kwargs = {} if ft_kwargs is None else ft_kwargs
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    if phase_kwargs is None:
        phase_kwargs = {'major_divider': 1, 'minor_divider': 6}

    fig = plt.figure(figsize=(16, 4))

    if title:
        fig.suptitle(title)

    # Set up plot layout.
    gs0 = gridspec.GridSpec(1, 2, figure=fig)
    gs00 = gs0[0].subgridspec(1, 1)
    gs01 = gs0[1].subgridspec(2, 1, height_ratios=[3, 1], hspace=0)

    # Plot the signal.
    ax1 = fig.add_subplot(gs00[0, 0])
    plot_signal(signal, sample_rate, ax=ax1, **signal_kwargs)

    # Plot the spectrum.
    signal_fft, freqs = ft_spectrum(signal, sample_rate, **ft_kwargs)
    _, (ax2, ax3) = plot_combined_spectrum(
        signal_fft, freqs,
        fig=fig, gs=gs01,
        spectrum_kwargs=spectrum_kwargs,
        phase_kwargs=phase_kwargs,
        **phase_spectrum_kwargs
    )

    return fig, np.array([ax1, ax2, ax3])


def multiple_formatter(denominator=2, number=np.pi, latex=r'\pi'):
    """Return a ``(x, pos)`` tick-format function labelling ``x`` as a
    reduced fraction of ``number``, written with the LaTeX symbol ``latex``.

    ``x`` is rounded to the nearest multiple of ``number / denominator``
    before formatting.

    From https://stackoverflow.com/a/53586826
    """
    def gcd(a, b):
        # Kept local rather than math.gcd so float denominators (e.g. 2.0)
        # keep working; math.gcd accepts integers only.
        while b:
            a, b = b, a % b
        return a

    def _multiple_formatter(x, pos):
        den = denominator
        # The builtin int replaces np.int, which was removed in NumPy 1.24.
        num = int(np.rint(den * x / number))
        com = gcd(num, den)
        num, den = int(num / com), int(den / com)

        if den == 1:
            if num == 0:
                return r'$0$'
            if num == 1:
                return r'$%s$' % latex
            if num == -1:
                return r'$-%s$' % latex
            return r'$%s%s$' % (num, latex)

        if num == 1:
            return r'$\frac{%s}{%s}$' % (latex, den)
        if num == -1:
            return r'$\frac{-%s}{%s}$' % (latex, den)
        return r'$\frac{%s%s}{%s}$' % (num, latex, den)

    return _multiple_formatter


class Multiple:
    """Locator/formatter factory for ticks at multiples of ``number / denominator``.

    From https://stackoverflow.com/a/53586826
    """

    def __init__(self, denominator=2, number=np.pi, latex=r'\pi'):
        self.denominator = denominator
        self.number = number
        self.latex = latex  # LaTeX symbol used for ``number`` in tick labels

    def locator(self):
        """Return a locator placing ticks every ``number / denominator``."""
        return plt.MultipleLocator(self.number / self.denominator)

    def formatter(self):
        """Return a formatter labelling ticks as fractions of ``number``."""
        return plt.FuncFormatter(multiple_formatter(self.denominator, self.number, self.latex))