2022-10-31 18:16:03 +01:00
|
|
|
"""
|
2022-11-02 16:27:57 +01:00
|
|
|
Functions to simplify plotting of Fourier spectra
|
2022-10-31 18:16:03 +01:00
|
|
|
"""
|
2022-11-02 19:05:53 +01:00
|
|
|
# vim: fdm=indent ts=4
|
|
|
|
|
2022-10-31 18:16:03 +01:00
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import matplotlib.gridspec as gridspec
|
|
|
|
import numpy as np
|
|
|
|
|
2022-11-02 19:05:53 +01:00
|
|
|
from .fft import ft_spectrum
|
|
|
|
|
2022-11-02 16:27:57 +01:00
|
|
|
def plot_spectrum(
    spectrum, freqs,
    plot_complex=False, plot_power=False, plot_amplitude=None,
    freq_unit="Hz", freq_scaler=1,
    title='Spectrum', xlabel='Frequency', ax=None,
    **plot_kwargs
):
    """
    Plot a signal's spectrum on an Axes object.

    Parameters
    ----------
    spectrum : array-like
        Spectrum values; they are fed to ``np.real``, ``np.imag`` and
        ``np.abs`` depending on which traces are requested.
    freqs : array-like
        Frequency belonging to each spectrum value.  Divided by
        `freq_scaler` before plotting.
    plot_complex : bool
        Draw the real and the imaginary part as two traces.
    plot_power : bool
        Draw ``abs(spectrum)**2``.
    plot_amplitude : bool or None
        Draw ``abs(spectrum)``.  Switched on regardless of the given value
        when neither `plot_power` nor `plot_complex` is set, so at least
        one trace is always drawn.
    freq_unit : str
        Appended to the x label in square brackets; a falsy value omits it.
    freq_scaler : number
        Divisor applied to `freqs`.
    title, xlabel : str
        Axes title and x label text.
    ax : Axes, optional
        Axes to draw on; ``plt.gca()`` is used when omitted.
    **plot_kwargs
        Forwarded to every ``ax.plot`` call.  An ``alpha`` entry overrides
        the default opacity (0.5 when `plot_complex` is set, otherwise 1).
        ``label`` cannot be supplied because every trace sets its own.

    Returns
    -------
    The Axes that was drawn on.
    """
    plot_amplitude = plot_amplitude or (not plot_power and not plot_complex)

    # Take alpha out of plot_kwargs: it is also passed explicitly below and
    # would otherwise be given twice to ax.plot.
    alpha = plot_kwargs.pop('alpha', 0.5 if plot_complex else 1)

    if ax is None:
        ax = plt.gca()

    ax.set_title(title)
    ax.set_xlabel(xlabel + (" [" + freq_unit + "]" if freq_unit else ""))

    ylabel_parts = []
    if plot_amplitude or plot_complex:
        ylabel_parts.append("Amplitude")
    if plot_power:
        ylabel_parts.append("Power")
    ax.set_ylabel("|".join(ylabel_parts))

    # asanyarray lets plain sequences be divided while keeping array subclasses
    scaled_freqs = np.asanyarray(freqs) / freq_scaler

    traces = []
    if plot_complex:
        traces.append(('Real', np.real(spectrum)))
        traces.append(('Imag', np.imag(spectrum)))
    if plot_power:
        traces.append(('Power', np.abs(spectrum)**2))
    if plot_amplitude:
        traces.append(('Abs', np.abs(spectrum)))

    for label, values in traces:
        ax.plot(scaled_freqs, values, '.-', label=label, alpha=alpha, **plot_kwargs)

    ax.legend()

    return ax
|
|
|
|
|
2022-11-02 16:27:57 +01:00
|
|
|
def plot_phase(
    spectrum, freqs,
    ylim_epsilon=0.5, ax=None, grid=True,
    freq_unit="Hz", freq_scaler=1, xlabel='Frequency',
    major_divider=2, minor_divider=12,
    **plot_kwargs
):
    """
    Draw the phase angle of a spectrum against frequency.

    The angle comes from ``np.angle(spectrum)`` and is plotted against
    ``freqs/freq_scaler``.  The y range is fixed to
    ``[-pi - ylim_epsilon, pi + ylim_epsilon]`` and the y axis is handed to
    ``axis_pi_ticker`` with the given dividers.  ``plt.gca()`` is used when
    no `ax` is passed; remaining keyword arguments go to ``ax.plot``.

    Returns the Axes that was drawn on.
    """
    target = plt.gca() if ax is None else ax

    # the unit, when given, goes behind the label in square brackets
    unit_suffix = " [" + freq_unit + "]" if freq_unit else ""

    target.set_ylabel("Phase")
    target.set_xlabel(xlabel + unit_suffix)
    target.grid(grid)

    phases = np.angle(spectrum)
    target.plot(freqs/freq_scaler, phases, '.-', **plot_kwargs)

    lower = -np.pi - ylim_epsilon
    upper = np.pi + ylim_epsilon
    target.set_ylim(lower, upper)

    axis_pi_ticker(
        target.yaxis,
        major_divider=major_divider,
        minor_divider=minor_divider,
    )

    return target
|
|
|
|
|
2022-11-02 19:05:53 +01:00
|
|
|
def axis_pi_ticker(axis, major_divider=2, minor_divider=12):
    """
    Put ticks at fractions of pi on `axis`.

    Major ticks are placed every ``pi/major_divider`` and receive labels
    from the ``Multiple`` formatter; minor ticks are placed every
    ``pi/minor_divider`` and only get a locator (no formatter is set for
    them here).

    Returns the same `axis` object.
    """
    # raw strings: '\p' is not a valid escape sequence in a normal literal
    major_tick = Multiple(major_divider, np.pi, r'\pi')
    minor_tick = Multiple(minor_divider, np.pi, r'\pi')

    axis.set_major_locator(major_tick.locator())
    axis.set_major_formatter(major_tick.formatter())
    axis.set_minor_locator(minor_tick.locator())

    return axis
|
|
|
|
|
2022-11-02 16:27:57 +01:00
|
|
|
def plot_signal(
    signal, sample_rate=1,
    time=None, ax=None,
    title='Signal', ylabel='Amplitude',
    time_unit="s", xlabel='Time',
    **kwargs
):
    """
    Draw a signal against time.

    When `time` is not given it is built as
    ``np.arange(len(signal))/sample_rate``.  A truthy `time_unit` is
    appended to `xlabel` in square brackets.  ``plt.gca()`` is used when no
    `ax` is passed; remaining keyword arguments go to ``ax.plot``.

    Returns the Axes that was drawn on.
    """
    target = ax if ax is not None else plt.gca()

    if time is None:
        # no explicit time axis: sample index divided by the sample rate
        sample_index = np.arange(len(signal))
        time = sample_index/sample_rate

    unit_suffix = " [" + time_unit + "]" if time_unit else ""

    target.set_title(title)
    target.set_xlabel(xlabel + unit_suffix)
    target.set_ylabel(ylabel)

    target.plot(time, signal, **kwargs)

    return target
|
|
|
|
|
2022-11-02 16:27:57 +01:00
|
|
|
def plot_combined_spectrum(
    spectrum, freqs,
    fig=None, gs=None,
    spectrum_ax=None, phase_ax=None,
    spectrum_kwargs=None, phase_kwargs=None,
    **shared_kwargs
):
    """
    Plot both the spectrum and its phase in one figure.

    Parameters
    ----------
    spectrum, freqs
        Passed on unchanged to ``plot_spectrum`` and ``plot_phase``.
    fig, gs
        Figure and grid spec used to create the two axes when neither
        `spectrum_ax` nor `phase_ax` is given.  A missing `fig` is created
        with ``figsize=(8, 16)``; a missing `gs` is a 2x1 grid with height
        ratios 3:1 and no vertical spacing.
    spectrum_ax, phase_ax
        Existing axes to draw on.  If only one of them is given, the other
        plot falls back to the current axes (``plt.gca()``) inside the
        respective plot function.
    spectrum_kwargs, phase_kwargs : dict, optional
        Extra keyword arguments for ``plot_spectrum`` / ``plot_phase``;
        they take precedence over `shared_kwargs`.
    **shared_kwargs
        Keyword arguments given to both plot functions.

    Returns
    -------
    (fig, axes)
        `fig` is the figure passed in or created here; it stays ``None``
        when axes were supplied without a figure.  `axes` is an array
        holding the spectrum axes and the phase axes.
    """
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    phase_kwargs = {} if phase_kwargs is None else phase_kwargs

    ax1, ax2 = spectrum_ax, phase_ax

    if ax1 is None and ax2 is None:
        # configure plotting layout
        if fig is None:
            fig = plt.figure(figsize=(8, 16))

        if gs is None:
            gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3, 1], hspace=0)

        ax1 = fig.add_subplot(gs[:-1, -1])
        ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1)

    # Both plot functions return the axes they actually drew on; keeping
    # that return value means ax1/ax2 are never None below, even when only
    # one of spectrum_ax/phase_ax was supplied.
    ax1 = plot_spectrum(spectrum, freqs, ax=ax1, **{**shared_kwargs, **spectrum_kwargs})
    ax2 = plot_phase(spectrum, freqs, ax=ax2, **{**shared_kwargs, **phase_kwargs})

    # move the spectrum's x ticks to the top and hide their labels
    ax1.xaxis.tick_top()
    for label in ax1.get_xticklabels():
        label.set_visible(False)

    axes = np.array([ax1, ax2])

    return fig, axes
|
|
|
|
|
2022-11-02 19:05:53 +01:00
|
|
|
def plot_signal_and_spectrum(
    signal, sample_rate=1, title=None,
    signal_kwargs=None, ft_kwargs=None,
    spectrum_kwargs=None, phase_kwargs=None,
    **phase_spectrum_kwargs
):
    """
    Create a figure showing both the signal and the combined spectrum.

    The figure is 16x4 and split into two columns: the left one holds the
    time-domain signal, the right one the spectrum stacked on top of its
    phase (height ratio 3:1).

    Parameters
    ----------
    signal, sample_rate
        Passed to ``plot_signal`` and to ``ft_spectrum``.
    title : str, optional
        Figure-level title; skipped when falsy.
    signal_kwargs : dict, optional
        Keyword arguments for ``plot_signal``.
    ft_kwargs : dict, optional
        Keyword arguments for ``ft_spectrum``.
    spectrum_kwargs, phase_kwargs : dict, optional
        Forwarded to ``plot_combined_spectrum``.  `phase_kwargs` defaults
        to ``{'major_divider': 1, 'minor_divider': 6}``.
    **phase_spectrum_kwargs
        Remaining keyword arguments for ``plot_combined_spectrum``.

    Returns
    -------
    (fig, axes)
        The figure and an array with the signal, spectrum and phase axes.
    """
    # None sentinels replace mutable default dicts in the signature
    signal_kwargs = {} if signal_kwargs is None else signal_kwargs
    ft_kwargs = {} if ft_kwargs is None else ft_kwargs
    spectrum_kwargs = {} if spectrum_kwargs is None else spectrum_kwargs
    if phase_kwargs is None:
        phase_kwargs = {'major_divider': 1, 'minor_divider': 6}

    fig = plt.figure(figsize=(16, 4))

    if title:
        fig.suptitle(title)

    # setup plot layout
    gs0 = gridspec.GridSpec(1, 2, figure=fig)
    gs00 = gs0[0].subgridspec(1, 1)
    gs01 = gs0[1].subgridspec(2, 1, height_ratios=[3, 1], hspace=0)

    # plot the signal
    ax1 = fig.add_subplot(gs00[0, 0])
    plot_signal(signal, sample_rate, ax=ax1, **signal_kwargs)

    # plot spectrum
    signal_fft, freqs = ft_spectrum(signal, sample_rate, **ft_kwargs)
    _, (ax2, ax3) = plot_combined_spectrum(
        signal_fft, freqs,
        fig=fig, gs=gs01,
        spectrum_kwargs=spectrum_kwargs, phase_kwargs=phase_kwargs,
        **phase_spectrum_kwargs
    )

    # return the axes
    axes = np.array([ax1, ax2, ax3])

    return fig, axes
|
|
|
|
|
2022-11-02 16:27:57 +01:00
|
|
|
def multiple_formatter(denominator=2, number=np.pi, latex=r'\pi'):
    """
    Build a tick-label callback that writes values as fractions of `number`.

    From https://stackoverflow.com/a/53586826

    Parameters
    ----------
    denominator : int
        Values are rounded to the nearest multiple of ``number/denominator``.
    number : float
        The unit the labels are expressed in (pi by default).
    latex : str
        LaTeX source used for `number` inside the label.

    Returns
    -------
    callable
        ``f(x, pos)`` returning a math-mode string such as ``$0$``,
        ``$-\\pi$``, ``$2\\pi$`` or ``$\\frac{3\\pi}{2}$``.  `pos` is accepted
        but not used.
    """
    def gcd(a, b):
        # Euclid's algorithm
        while b:
            a, b = b, a % b
        return a

    def _multiple_formatter(x, pos):
        den = denominator
        # builtin int: the np.int alias was removed from NumPy in 1.24
        num = int(np.rint(den * x / number))

        # reduce the fraction num/den
        com = gcd(num, den)
        num, den = int(num // com), int(den // com)

        if den == 1:
            if num == 0:
                return r'$0$'
            if num == 1:
                return r'$%s$' % latex
            elif num == -1:
                return r'$-%s$' % latex
            else:
                return r'$%s%s$' % (num, latex)
        else:
            if num == 1:
                return r'$\frac{%s}{%s}$' % (latex, den)
            elif num == -1:
                return r'$\frac{-%s}{%s}$' % (latex, den)
            else:
                return r'$\frac{%s%s}{%s}$' % (num, latex, den)

    return _multiple_formatter
|
|
|
|
|
|
|
|
class Multiple:
    """
    Pair of tick locator and formatter for multiples of a number.

    From https://stackoverflow.com/a/53586826
    """
    def __init__(self, denominator=2, number=np.pi, latex=r'\pi'):
        # raw string default: '\p' is not a valid escape sequence
        self.denominator = denominator
        self.number = number
        self.latex = latex

    def locator(self):
        """Return a MultipleLocator with a step of ``number/denominator``."""
        return plt.MultipleLocator(self.number / self.denominator)

    def formatter(self):
        """Return a FuncFormatter wrapping ``multiple_formatter`` for this instance's settings."""
        return plt.FuncFormatter(multiple_formatter(self.denominator, self.number, self.latex))
|