Skip to content

Commit cc344bf

Browse files
authored
Merge pull request #78 from jtgasparik/plot_norm_deriv
Function to plot the nomalized derivative
2 parents 7750859 + 6bf74bc commit cc344bf

5 files changed

Lines changed: 76 additions & 1 deletion

File tree

pysp2/util/normalized_derivative_method.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,47 @@ def central_difference(S, num_records=None, normalize=True):
7373

7474
return xr.Dataset(dSdt)
7575

76+
77+
def plot_normalized_derivative(ds, record_no, chn=0):
78+
"""
79+
Plots the normalized derivative of the scattering signal for a given record_no and channel.
80+
81+
Parameters
82+
----------
83+
ds: xarray Dataset
84+
The dataset containing the normalized derivative to plot.
85+
record_no: int
86+
The record number to plot.
87+
chn: int
88+
The channel number to plot (0 or 4).
89+
Returns
90+
-------
91+
ax: matplotlib Axes
92+
The axes object containing the plot.
93+
"""
94+
import matplotlib.pyplot as plt
95+
96+
if chn not in [0, 4]:
97+
raise ValueError("Channel number must be 0 or 4.")
98+
99+
spectra = ds.isel(event_index=record_no)
100+
time = spectra['time'].values
101+
inp_data = {}
102+
inp_data['time'] = xr.DataArray(np.array(time[np.newaxis]),
103+
dims=['time'])
104+
inp_data['Data_ch' + str(chn)] = xr.DataArray(
105+
spectra['Data_ch' + str(chn)].values[np.newaxis, :],
106+
dims=['time', 'bins'])
107+
inp_data = xr.Dataset(inp_data)
108+
bins = np.linspace(0, 100, 100)
109+
110+
ch_name = f'Data_ch{chn}'
111+
plt.figure(figsize=(10, 6))
112+
ax = plt.gca()
113+
inp_data[ch_name].plot(ax=ax)
114+
ax.set_title(f'Normalized Derivative of Scattering Signal - Channel {chn} Record {record_no}')
115+
ax.set_xlabel('Time (s)')
116+
ax.set_ylabel('Normalized Derivative')
117+
plt.grid()
118+
119+
return ax
52.2 KB
Loading

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
import matplotlib.pyplot as plt
3+
4+
@pytest.fixture(autouse=True)
5+
def close_figures():
6+
yield
7+
plt.close("all")

tests/test_ndm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ def test_central_difference():
2121
np.testing.assert_almost_equal(dSdt_norm['Data_ch4'].isel(event_index=5876, time=99).item(),
2222
7.166666666e7/-30152, decimal=2)
2323
np.testing.assert_almost_equal(dSdt_norm['Data_ch4'].isel(event_index=5876, time=19).item(),
24-
1.5e7/-30132, decimal=2)
24+
1.5e7/-30132, decimal=2)
25+

tests/test_vis.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import matplotlib
2+
import pytest
3+
import xarray as xr
4+
import numpy as np
5+
6+
import pysp2
7+
from pysp2.util.normalized_derivative_method import plot_normalized_derivative
8+
9+
matplotlib.use("Agg")
10+
11+
@pytest.mark.mpl_image_compare(tolerance=10)
12+
def test_plot_normalized_derivative():
13+
14+
my_sp2b = pysp2.io.read_sp2(pysp2.testing.EXAMPLE_SP2B)
15+
my_ini = pysp2.io.read_config(pysp2.testing.EXAMPLE_INI)
16+
my_binary = pysp2.util.gaussian_fit(my_sp2b, my_ini, parallel=False)
17+
dSdt_norm = pysp2.util.central_difference(my_binary, normalize=True)
18+
19+
# Test the plotting function for channel 0 and record number 2
20+
ax = plot_normalized_derivative(dSdt_norm, record_no=304, chn=0)
21+
fig = ax.figure
22+
23+
return fig

0 commit comments

Comments
 (0)