""" Functions to simplify plotting """ import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import numpy as np def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_amplitude=None, ax=None, freq_unit="Hz", freq_scaler=1): """ Plot a signal's spectrum on an Axis object""" plot_amplitude = plot_amplitude or (not plot_power and not plot_complex) alpha = 1 if ax is None: ax = plt.gca() ax.set_title("Spectrum") ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) ylabel = "" if plot_amplitude or plot_complex: ylabel = "Amplitude" if plot_power: if ylabel: ylabel += "|" ylabel += "Power" ax.set_ylabel(ylabel) if plot_complex: alpha = 0.5 ax.plot(freqs/freq_scaler, np.real(spectrum), '.-', label='Real', alpha=alpha) ax.plot(freqs/freq_scaler, np.imag(spectrum), '.-', label='Imag', alpha=alpha) if plot_power: ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha) if plot_amplitude: ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha) ax.legend() return ax def plot_phase( spectrum, freqs, ylim_epsilon=0.5, ax=None, freq_unit="Hz", freq_scaler=1): if ax is None: ax = plt.gca() ax.set_ylabel("Phase") ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-') ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon) return ax def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **kwargs): if ax is None: ax = plt.gca() if time is None: time = np.arange(len(signal))/sample_rate ax.set_title("Signal") ax.set_xlabel("t" + (" ["+time_unit+"]" if time_unit else "" )) ax.set_ylabel("A(t)") ax.plot(time, signal, **kwargs) return ax def plot_combined_spectrum(spectrum, freqs, spectrum_kwargs={}, fig=None, gs=None, freq_scaler=1, freq_unit="Hz"): """Plot both the frequencies and phase in one figure.""" # configure plotting layout if fig is None: fig = plt.figure(figsize=(8, 16)) if gs is None: gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3,1], hspace=0) ax1 = fig.add_subplot(gs[:-1, -1]) ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1) axes = np.array([ax1, ax2]) # plot the spectrum plot_spectrum(spectrum, freqs, ax=ax1, freq_scaler=freq_scaler, freq_unit=freq_unit, **spectrum_kwargs) # plot the phase plot_phase(spectrum, freqs, ax=ax2, freq_scaler=freq_scaler, freq_unit=freq_unit) ax1.xaxis.tick_top() [label.set_visible(False) for label in ax1.get_xticklabels()] return fig, axes