m-thesis-introduction/fourier/mylib/plotting.py

187 lines
5.1 KiB
Python

"""
Functions to simplify plotting of Fourier spectra
"""
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
def plot_spectrum(
        spectrum, freqs,
        plot_complex=False, plot_power=False, plot_amplitude=None,
        freq_unit="Hz", freq_scaler=1,
        title='Spectrum', xlabel='Frequency', ax=None,
        **plot_kwargs
        ):
    """
    Plot a signal's spectrum on an Axes object.

    Parameters
    ----------
    spectrum : array-like
        Spectrum values; passed to np.real, np.imag and np.abs.
    freqs : array-like
        Frequencies belonging to `spectrum`; divided by `freq_scaler`
        before plotting. Plain lists and tuples are accepted.
    plot_complex : bool
        Draw the real and imaginary parts as two separate lines.
    plot_power : bool
        Draw abs(spectrum)**2.
    plot_amplitude : bool or None
        Draw abs(spectrum). Any falsy value (including an explicit False)
        still results in the amplitude being drawn when neither
        `plot_power` nor `plot_complex` is set.
    freq_unit : str
        Unit appended to the x-axis label as " [unit]"; a falsy value
        leaves the label without a unit.
    freq_scaler : number
        Divisor applied to `freqs`.
    title, xlabel : str
        Axes title and x-axis label.
    ax : Axes or None
        Axes to draw on; `plt.gca()` is used when None.
    **plot_kwargs
        Forwarded to every `ax.plot` call. An `alpha` entry overrides the
        default transparency instead of clashing with it.

    Returns
    -------
    The Axes object that was drawn on.
    """
    # Amplitude is the fallback when nothing else was requested.
    plot_amplitude = plot_amplitude or (not plot_power and not plot_complex)

    # All lines are drawn half-transparent when the complex parts are shown;
    # a caller-supplied alpha takes precedence over this default.
    alpha = plot_kwargs.pop('alpha', 0.5 if plot_complex else 1)

    if ax is None:
        ax = plt.gca()

    ax.set_title(title)
    ax.set_xlabel(xlabel + (" [" + freq_unit + "]" if freq_unit else ""))

    ylabel = ""
    if plot_amplitude or plot_complex:
        ylabel = "Amplitude"
    if plot_power:
        if ylabel:
            ylabel += "|"
        ylabel += "Power"
    ax.set_ylabel(ylabel)

    # Scale once; np.true_divide also copes with plain Python sequences,
    # for which `freqs / freq_scaler` would raise a TypeError.
    scaled_freqs = np.true_divide(freqs, freq_scaler)

    if plot_complex:
        ax.plot(scaled_freqs, np.real(spectrum), '.-', label='Real', alpha=alpha, **plot_kwargs)
        ax.plot(scaled_freqs, np.imag(spectrum), '.-', label='Imag', alpha=alpha, **plot_kwargs)

    if plot_power:
        ax.plot(scaled_freqs, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha, **plot_kwargs)

    if plot_amplitude:
        ax.plot(scaled_freqs, np.abs(spectrum), '.-', label='Abs', alpha=alpha, **plot_kwargs)

    ax.legend()

    return ax
def plot_phase(
        spectrum, freqs,
        ylim_epsilon=0.5, ax=None, grid=True,
        freq_unit="Hz", freq_scaler=1, xlabel='Frequency',
        **plot_kwargs
        ):
    """
    Plot the phase of a spectrum on an Axes object.

    Parameters
    ----------
    spectrum : array-like
        Spectrum values; the phase is obtained with np.angle.
    freqs : array-like
        Frequencies belonging to `spectrum`; divided by `freq_scaler`
        before plotting. Plain lists and tuples are accepted.
    ylim_epsilon : number
        Margin added beyond -pi and +pi for the y-axis limits.
    ax : Axes or None
        Axes to draw on; `plt.gca()` is used when None.
    grid : bool
        Passed to `ax.grid`.
    freq_unit : str
        Unit appended to the x-axis label as " [unit]"; a falsy value
        leaves the label without a unit.
    freq_scaler : number
        Divisor applied to `freqs`.
    xlabel : str
        Label of the x-axis.
    **plot_kwargs
        Forwarded to `ax.plot`.

    Returns
    -------
    The Axes object that was drawn on.
    """
    if ax is None:
        ax = plt.gca()

    ax.set_ylabel("Phase")
    ax.set_xlabel(xlabel + (" [" + freq_unit + "]" if freq_unit else ""))
    ax.grid(grid)

    # np.true_divide also copes with plain Python sequences, for which
    # `freqs / freq_scaler` would raise a TypeError.
    ax.plot(np.true_divide(freqs, freq_scaler), np.angle(spectrum), '.-', **plot_kwargs)
    ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon)

    # Major ticks every pi/2 (with labels), minor ticks every pi/12.
    # Raw strings: '\p' is not a valid escape sequence.
    major_tick = Multiple(2, np.pi, r'\pi')
    minor_tick = Multiple(12, np.pi, r'\pi')
    ax.yaxis.set_major_locator(major_tick.locator())
    ax.yaxis.set_major_formatter(major_tick.formatter())
    ax.yaxis.set_minor_locator(minor_tick.locator())

    return ax
def plot_signal(
        signal, sample_rate=1,
        time=None, ax=None,
        title='Signal', ylabel='Amplitude',
        time_unit="s", xlabel='Time',
        **kwargs):
    """
    Draw a signal in the time domain on an Axes object.

    When `time` is None the time axis is built from the sample indices
    divided by `sample_rate`. When `ax` is None the current Axes
    (`plt.gca()`) is used. `time_unit` is appended to the x-axis label as
    " [unit]" unless it is falsy. Remaining keyword arguments are
    forwarded to `ax.plot`.

    Returns the Axes object that was drawn on.
    """
    target = plt.gca() if ax is None else ax

    # Fall back to an index-based time axis when none is supplied.
    timestamps = np.arange(len(signal))/sample_rate if time is None else time

    target.set_title(title)

    unit_suffix = " [" + time_unit + "]" if time_unit else ""
    target.set_xlabel(xlabel + unit_suffix)
    target.set_ylabel(ylabel)

    target.plot(timestamps, signal, **kwargs)

    return target
def plot_combined_spectrum(
        spectrum, freqs,
        fig=None, gs=None,
        spectrum_ax=None, phase_ax=None,
        spectrum_kwargs=None, phase_kwargs=None,
        **shared_kwargs
        ):
    """
    Plot both the spectrum and its phase in one figure.

    Parameters
    ----------
    spectrum, freqs
        Forwarded to `plot_spectrum` and `plot_phase`.
    fig : Figure or None
        Figure to add the two subplots to. Only used (and only created
        when None) if neither `spectrum_ax` nor `phase_ax` is given.
    gs : GridSpec or None
        Layout for the two subplots; defaults to a 2x1 grid with height
        ratios 3:1 and no vertical spacing.
    spectrum_ax, phase_ax : Axes or None
        Existing Axes to draw on. If at least one is given no new layout
        is created; a missing one falls back to whatever Axes the
        respective plot function selects itself.
    spectrum_kwargs, phase_kwargs : dict or None
        Extra keyword arguments for `plot_spectrum` / `plot_phase`;
        they take precedence over `shared_kwargs`.
    **shared_kwargs
        Keyword arguments passed to both plot functions.

    Returns
    -------
    (fig, axes)
        `fig` as passed in or created here (it stays None when existing
        Axes were supplied without a figure), and a numpy array holding
        the spectrum Axes and the phase Axes.
    """
    # None instead of `{}` as default avoids sharing one dict between calls.
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    phase_kwargs = {} if phase_kwargs is None else phase_kwargs

    if spectrum_ax is None and phase_ax is None:
        # configure plotting layout
        if fig is None:
            fig = plt.figure(figsize=(8, 16))

        if gs is None:
            gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3, 1], hspace=0)

        spectrum_ax = fig.add_subplot(gs[:-1, -1])
        phase_ax = fig.add_subplot(gs[-1, -1], sharex=spectrum_ax)

    # Use the Axes returned by the plot functions: when only one of
    # spectrum_ax/phase_ax was supplied the other is still None here, and
    # dereferencing it below would raise an AttributeError.
    ax1 = plot_spectrum(spectrum, freqs, ax=spectrum_ax, **{**shared_kwargs, **spectrum_kwargs})
    ax2 = plot_phase(spectrum, freqs, ax=phase_ax, **{**shared_kwargs, **phase_kwargs})

    # Move the spectrum's x ticks to the top edge and hide their labels.
    ax1.xaxis.tick_top()
    for label in ax1.get_xticklabels():
        label.set_visible(False)

    axes = np.array([ax1, ax2])

    return fig, axes
def multiple_formatter(denominator=2, number=np.pi, latex=r'\pi'):
    r"""
    Build a tick-formatter callback that labels values as fractions of `number`.

    From https://stackoverflow.com/a/53586826

    Parameters
    ----------
    denominator : int
        Tick values are rounded to the nearest multiple of
        `number / denominator` before being reduced to lowest terms.
    number : float
        The base quantity the ticks are multiples of (pi by default).
    latex : str
        LaTeX snippet used to render `number` in the label.

    Returns
    -------
    A function `(x, pos) -> str` producing a math-text label such as
    `$0$`, `$-\pi$`, `$2\pi$` or `$\frac{3\pi}{2}$`. `pos` is ignored.
    """
    def gcd(a, b):
        # Euclid's algorithm.
        while b:
            a, b = b, a % b
        return a

    def _multiple_formatter(x, pos):
        den = denominator
        # Builtin int: the `np.int` alias was removed from NumPy.
        num = int(np.rint(den*x/number))

        # Reduce num/den to lowest terms.
        com = gcd(num, den)
        (num, den) = (int(num/com), int(den/com))

        if den == 1:
            # Whole multiples of `number`.
            if num == 0:
                return r'$0$'
            if num == 1:
                return r'$%s$' % latex
            elif num == -1:
                return r'$-%s$' % latex
            else:
                return r'$%s%s$' % (num, latex)
        else:
            # Proper fractions; a coefficient of +-1 is not written out.
            if num == 1:
                return r'$\frac{%s}{%s}$' % (latex, den)
            elif num == -1:
                return r'$\frac{-%s}{%s}$' % (latex, den)
            else:
                return r'$\frac{%s%s}{%s}$' % (num, latex, den)

    return _multiple_formatter
class Multiple:
    """
    Bundle a tick locator and formatter for multiples of number/denominator.

    From https://stackoverflow.com/a/53586826

    Parameters
    ----------
    denominator : int
        Ticks are placed every `number / denominator`.
    number : float
        Base quantity (pi by default).
    latex : str
        LaTeX snippet used to render `number` in tick labels.
    """

    # Raw string default: '\p' is not a valid escape sequence.
    def __init__(self, denominator=2, number=np.pi, latex=r'\pi'):
        self.denominator = denominator
        self.number = number
        self.latex = latex

    def locator(self):
        """Return a MultipleLocator with step number/denominator."""
        return plt.MultipleLocator(self.number / self.denominator)

    def formatter(self):
        """Return a FuncFormatter wrapping `multiple_formatter` with this instance's settings."""
        return plt.FuncFormatter(multiple_formatter(self.denominator, self.number, self.latex))