Update fourier plotting lib

This commit is contained in:
Eric Teunis de Boone 2022-11-02 16:27:57 +01:00
parent e90169464c
commit 87d12748e4

View file

@ -1,11 +1,17 @@
""" """
Functions to simplify plotting Functions to simplify plotting of fourier spectra
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
import numpy as np import numpy as np
def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_amplitude=None, ax=None, freq_unit="Hz", freq_scaler=1): def plot_spectrum(
spectrum, freqs,
plot_complex=False, plot_power=False, plot_amplitude=None,
freq_unit="Hz", freq_scaler=1,
title='Spectrum', xlabel='Frequency', ax=None,
**plot_kwargs
):
""" Plot a signal's spectrum on an Axis object""" """ Plot a signal's spectrum on an Axis object"""
plot_amplitude = plot_amplitude or (not plot_power and not plot_complex) plot_amplitude = plot_amplitude or (not plot_power and not plot_complex)
alpha = 1 alpha = 1
@ -13,8 +19,8 @@ def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_a
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
ax.set_title("Spectrum") ax.set_title(title)
ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) ax.set_xlabel(xlabel + (" ["+freq_unit+"]" if freq_unit else "" ))
ylabel = "" ylabel = ""
if plot_amplitude or plot_complex: if plot_amplitude or plot_complex:
ylabel = "Amplitude" ylabel = "Amplitude"
@ -26,70 +32,156 @@ def plot_spectrum( spectrum, freqs, plot_complex=False, plot_power=False, plot_a
if plot_complex: if plot_complex:
alpha = 0.5 alpha = 0.5
ax.plot(freqs/freq_scaler, np.real(spectrum), '.-', label='Real', alpha=alpha) ax.plot(freqs/freq_scaler, np.real(spectrum), '.-', label='Real', alpha=alpha, **plot_kwargs)
ax.plot(freqs/freq_scaler, np.imag(spectrum), '.-', label='Imag', alpha=alpha) ax.plot(freqs/freq_scaler, np.imag(spectrum), '.-', label='Imag', alpha=alpha, **plot_kwargs)
if plot_power: if plot_power:
ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha) ax.plot(freqs/freq_scaler, np.abs(spectrum)**2, '.-', label='Power', alpha=alpha, **plot_kwargs)
if plot_amplitude: if plot_amplitude:
ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha) ax.plot(freqs/freq_scaler, np.abs(spectrum), '.-', label='Abs', alpha=alpha, **plot_kwargs)
ax.legend() ax.legend()
return ax return ax
def plot_phase( spectrum, freqs, ylim_epsilon=0.5, ax=None, freq_unit="Hz", freq_scaler=1): def plot_phase(
spectrum, freqs,
ylim_epsilon=0.5, ax=None, grid=True,
freq_unit="Hz", freq_scaler=1, xlabel='Frequency',
**plot_kwargs
):
"""
Plot the phase of spectrum
"""
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
ax.set_ylabel("Phase") ax.set_ylabel("Phase")
ax.set_xlabel("f" + (" ["+freq_unit+"]" if freq_unit else "" )) ax.set_xlabel(xlabel + (" ["+freq_unit+"]" if freq_unit else "" ))
ax.grid(grid)
ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-') ax.plot(freqs/freq_scaler, np.angle(spectrum), '.-', **plot_kwargs)
ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon) ax.set_ylim(-1*np.pi - ylim_epsilon, np.pi + ylim_epsilon)
major_tick = Multiple(2, np.pi, '\pi')
minor_tick = Multiple(12, np.pi, '\pi')
ax.yaxis.set_major_locator(major_tick.locator())
ax.yaxis.set_major_formatter(major_tick.formatter())
ax.yaxis.set_minor_locator(minor_tick.locator())
return ax return ax
def plot_signal( signal, sample_rate = 1, ax=None, time=None, time_unit="s", **kwargs): def plot_signal(
signal, sample_rate = 1,
time=None, ax=None,
title='Signal', ylabel='Amplitude',
time_unit="s", xlabel='Time',
**kwargs):
"""
Plot the signal in the time domain
"""
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
if time is None: if time is None:
time = np.arange(len(signal))/sample_rate time = np.arange(len(signal))/sample_rate
ax.set_title("Signal") ax.set_title(title)
ax.set_xlabel("t" + (" ["+time_unit+"]" if time_unit else "" )) ax.set_xlabel(xlabel + (" ["+time_unit+"]" if time_unit else "" ))
ax.set_ylabel("A(t)") ax.set_ylabel(ylabel)
ax.plot(time, signal, **kwargs) ax.plot(time, signal, **kwargs)
return ax return ax
def plot_combined_spectrum(spectrum, freqs, def plot_combined_spectrum(
spectrum_kwargs={}, fig=None, gs=None, freq_scaler=1, freq_unit="Hz"): spectrum, freqs,
"""Plot both the frequencies and phase in one figure.""" fig=None, gs=None,
spectrum_ax=None, phase_ax=None,
spectrum_kwargs={}, phase_kwargs={},
**shared_kwargs
):
"""
Plot both the frequencies and phase in one figure.
"""
# configure plotting layout ax1, ax2 = None, None
if fig is None: if spectrum_ax is not None or phase_ax is not None:
fig = plt.figure(figsize=(8, 16)) if spectrum_ax is not None:
ax1 = spectrum_ax
if phase_ax is not None:
ax2 = phase_ax
if gs is None: else:
gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3,1], hspace=0) # configure plotting layout
if fig is None:
fig = plt.figure(figsize=(8, 16))
ax1 = fig.add_subplot(gs[:-1, -1]) if gs is None:
ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1) gs = gridspec.GridSpec(2, 1, figure=fig, height_ratios=[3,1], hspace=0)
ax1 = fig.add_subplot(gs[:-1, -1])
ax2 = fig.add_subplot(gs[-1, -1], sharex=ax1)
axes = np.array([ax1, ax2]) axes = np.array([ax1, ax2])
# plot the spectrum # plot the spectrum
plot_spectrum(spectrum, freqs, ax=ax1, freq_scaler=freq_scaler, freq_unit=freq_unit, **spectrum_kwargs) plot_spectrum(spectrum, freqs, ax=ax1, **{ **shared_kwargs, **spectrum_kwargs})
# plot the phase # plot the phase
plot_phase(spectrum, freqs, ax=ax2, freq_scaler=freq_scaler, freq_unit=freq_unit) plot_phase(spectrum, freqs, ax=ax2, **{ **shared_kwargs, **phase_kwargs})
ax1.xaxis.tick_top() ax1.xaxis.tick_top()
[label.set_visible(False) for label in ax1.get_xticklabels()] [label.set_visible(False) for label in ax1.get_xticklabels()]
return fig, axes return fig, axes
def multiple_formatter(denominator=2, number=np.pi, latex='\pi'):
"""
From https://stackoverflow.com/a/53586826
"""
def gcd(a, b):
while b:
a, b = b, a%b
return a
def _multiple_formatter(x, pos):
den = denominator
num = np.int(np.rint(den*x/number))
com = gcd(num,den)
(num,den) = (int(num/com),int(den/com))
if den==1:
if num==0:
return r'$0$'
if num==1:
return r'$%s$'%latex
elif num==-1:
return r'$-%s$'%latex
else:
return r'$%s%s$'%(num,latex)
else:
if num==1:
return r'$\frac{%s}{%s}$'%(latex,den)
elif num==-1:
return r'$\frac{-%s}{%s}$'%(latex,den)
else:
return r'$\frac{%s%s}{%s}$'%(num,latex,den)
return _multiple_formatter
class Multiple:
"""
From https://stackoverflow.com/a/53586826
"""
def __init__(self, denominator=2, number=np.pi, latex='\pi'):
self.denominator = denominator
self.number = number
self.latex = latex
def locator(self):
return plt.MultipleLocator(self.number / self.denominator)
def formatter(self):
return plt.FuncFormatter(multiple_formatter(self.denominator, self.number, self.latex))