-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlinear.py
More file actions
251 lines (219 loc) · 7.73 KB
/
Copy pathlinear.py
File metadata and controls
251 lines (219 loc) · 7.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""Linear plot types: single scatter+fit and multi-line series.
This module provides two functions:
* ``linear_plot`` — scatter plot with optional linear-regression overlay.
* ``multi_line_plot`` — overlay multiple line/scatter series on shared axes.
Both functions return a ``plt.Figure`` that can be further customised or
passed directly to ``savefig()``.
Example
-------
>>> import numpy as np
>>> from academic_plot import linear_plot, multi_line_plot, savefig
>>> x, y = np.linspace(0, 10, 30), np.random.normal(0, 1, 30)
>>> fig = linear_plot(x, y, xlabel="Time (s)", ylabel="Voltage (V)")
>>> savefig(fig, "linear_demo")
"""
from __future__ import annotations

import warnings
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np

from .style import COLORS, MARKERS, PALETTE, FIGSIZES, GRID_ALPHA, GRID_LINEWIDTH, GRID_COLOR, Z_ORDER, savefig, apply_grid
from .utils import validate_arrays
def linear_plot(
    x: np.ndarray,
    y: np.ndarray,
    *,
    xlabel: str = "x",
    ylabel: str = "y",
    title: str | None = None,
    label: str | None = None,
    color: str | None = None,
    marker: str | None = None,
    marker_size: float = 12,
    marker_edgecolor: str = "white",
    marker_edgewidth: float = 0.3,
    fit_color: str | None = None,
    fit_linewidth: float = 1.2,
    show_fit: bool = True,
    show_grid: bool = True,
    grid_alpha: float = GRID_ALPHA,
    grid_linewidth: float = GRID_LINEWIDTH,
    grid_color: str = GRID_COLOR,
    legend_loc: str = "best",
    figsize: tuple[float, float] = FIGSIZES["single"],
) -> plt.Figure:
    """Single x-y scatter plot with an optional linear-regression fit line.

    Parameters
    ----------
    x, y : np.ndarray
        1-D data arrays of the same length (array-likes such as lists are
        also accepted for the fit).
    xlabel : str
        Label for the x-axis.
    ylabel : str
        Label for the y-axis.
    title : str or None
        Subplot title. No title is shown if ``None`` (or empty).
    label : str or None
        Legend label for the data points. Defaults to ``"Data"``.
    color : str or None
        Colour for the scatter markers. Falls back to ``COLORS["blue"]``.
    marker : str or None
        Matplotlib marker character (e.g. ``"o"``, ``"s"``, ``"D"``).
        Falls back to ``MARKERS[0]``.
    marker_size : float
        Marker area (points²).
    marker_edgecolor : str
        Colour of the marker border.
    marker_edgewidth : float
        Width of the marker border.
    fit_color : str or None
        Colour of the regression line. Falls back to ``COLORS["red"]``.
    fit_linewidth : float
        Width of the regression line.
    show_fit : bool
        If ``True``, draw a least-squares linear-regression line labelled
        ``"y = <slope>x ± <|intercept|>"``. Only points where both ``x`` and
        ``y`` are finite are used for the fit. If fewer than two distinct
        finite ``x`` values remain, no line is drawn and a ``UserWarning``
        is issued instead of producing a degenerate fit.
    show_grid : bool
        If ``True``, show a light background grid.
    grid_alpha : float
        Grid line opacity (0 = invisible, 1 = opaque).
    grid_linewidth : float
        Grid line thickness.
    grid_color : str
        Grid line colour.
    legend_loc : str
        Matplotlib legend location string (e.g. ``"best"``, ``"upper left"``).
    figsize : tuple[float, float]
        Figure size in inches (width, height).

    Returns
    -------
    plt.Figure
        The matplotlib Figure object.
    """
    validate_arrays(x, y, names=["x", "y"])

    fig, ax = plt.subplots(figsize=figsize)

    # Resolve style defaults (``None`` means "use the module's house style").
    c = color or COLORS["blue"]
    m = marker or MARKERS[0]
    fc = fit_color or COLORS["red"]

    # Scatter the raw data (non-finite points are simply not drawn by matplotlib).
    ax.scatter(
        x, y, s=marker_size, color=c,
        edgecolors=marker_edgecolor, linewidths=marker_edgewidth,
        marker=m, zorder=Z_ORDER["data"], label=label or "Data",
    )

    # Optional linear-regression overlay
    if show_fit:
        # Float copies: makes list inputs support .min()/.max() and lets us
        # drop NaN/inf points, which would otherwise break np.polyfit.
        xf = np.asarray(x, dtype=float)
        yf = np.asarray(y, dtype=float)
        finite = np.isfinite(xf) & np.isfinite(yf)
        xf, yf = xf[finite], yf[finite]

        if np.unique(xf).size < 2:
            # A line is undetermined by fewer than two distinct x values;
            # np.polyfit would only emit a RankWarning and a meaningless fit.
            warnings.warn(
                "linear_plot: fewer than two distinct finite x values; "
                "skipping the regression line.",
                stacklevel=2,
            )
        else:
            slope, intercept = np.polyfit(xf, yf, 1)
            x_line = np.linspace(xf.min(), xf.max(), 300)
            y_line = slope * x_line + intercept
            # Format the sign explicitly so a negative intercept reads
            # "y = 1.000x - 2.000" rather than "y = 1.000x + -2.000".
            sign = "-" if intercept < 0 else "+"
            ax.plot(
                x_line, y_line, color=fc, linewidth=fit_linewidth,
                label=f"y = {slope:.3f}x {sign} {abs(intercept):.3f}",
            )

    # Axis labels and title
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    # Grid configuration
    apply_grid(ax, show=show_grid, alpha=grid_alpha, linewidth=grid_linewidth, color=grid_color)
    ax.minorticks_on()
    ax.legend(loc=legend_loc)
    fig.tight_layout(pad=0.3)
    return fig
def multi_line_plot(
    datasets: Sequence[dict],
    *,
    xlabel: str = "x",
    ylabel: str = "y",
    title: str | None = None,
    show_grid: bool = True,
    grid_alpha: float = GRID_ALPHA,
    grid_linewidth: float = GRID_LINEWIDTH,
    grid_color: str = GRID_COLOR,
    linewidth: float = 1.2,
    marker_size: float = 3.5,
    marker_facecolor: str = "white",
    marker_edgewidth: float = 0.8,
    legend_loc: str = "best",
    figsize: tuple[float, float] = FIGSIZES["wide"],
) -> plt.Figure:
    """Draw several x-y series as marked lines on one shared pair of axes.

    Every entry of *datasets* is a mapping that describes one series:

    =============  ========  ==============================================
    Key            Required  Meaning
    =============  ========  ==============================================
    ``x``          Yes       1-D array of x values.
    ``y``          Yes       1-D array of y values.
    ``label``      No        Legend text; defaults to ``"Series N"`` (1-based).
    ``color``      No        Colour string; cycles through ``PALETTE`` if absent.
    ``marker``     No        Marker character; cycles through ``MARKERS`` if absent.
    ``linestyle``  No        Matplotlib line style; defaults to ``"-"``.
    =============  ========  ==============================================

    The fallbacks apply only when a key is *absent*. A key that is present
    with the value ``None`` is forwarded to ``ax.plot`` as-is.

    Parameters
    ----------
    datasets : sequence of dict
        One mapping per series, with ``x``/``y`` plus the optional style keys.
    xlabel : str
        Text for the x-axis label.
    ylabel : str
        Text for the y-axis label.
    title : str or None
        Subplot title; omitted when ``None`` or empty.
    show_grid : bool
        Whether to draw the light background grid.
    grid_alpha : float
        Grid line opacity.
    grid_linewidth : float
        Grid line thickness.
    grid_color : str
        Grid line colour.
    linewidth : float
        Line width shared by every series.
    marker_size : float
        Marker size shared by every series.
    marker_facecolor : str
        Fill colour used inside every marker.
    marker_edgewidth : float
        Marker border width shared by every series.
    legend_loc : str
        Matplotlib legend location.
    figsize : tuple[float, float]
        Figure size in inches.

    Returns
    -------
    plt.Figure
        The newly created matplotlib Figure.

    Example
    -------
    >>> datasets = [
    ...     {"x": t, "y": y1, "label": "Series A", "color": "#2171B5"},
    ...     {"x": t, "y": y2, "label": "Series B", "color": "#CB181D"},
    ... ]
    >>> fig = multi_line_plot(datasets, xlabel="Time (s)", ylabel="Value")
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Styling common to every series; only colour, marker, line style and
    # label vary from one series to the next.
    common_style = {
        "markersize": marker_size,
        "linewidth": linewidth,
        "markerfacecolor": marker_facecolor,
        "markeredgewidth": marker_edgewidth,
    }

    for idx, series in enumerate(datasets):
        xs, ys = series["x"], series["y"]
        validate_arrays(
            xs, ys,
            names=[f"datasets[{idx}]['x']", f"datasets[{idx}]['y']"],
        )

        # Cycled fallbacks, used only when the series omits the key.
        fallback_colour = PALETTE[idx % len(PALETTE)]
        fallback_marker = MARKERS[idx % len(MARKERS)]
        colour = series.get("color", fallback_colour)
        mark = series.get("marker", fallback_marker)
        style = series.get("linestyle", "-")
        name = series.get("label", f"Series {idx + 1}")

        ax.plot(
            xs, ys,
            color=colour,
            marker=mark,
            linestyle=style,
            markeredgecolor=colour,  # marker outline matches the line colour
            label=name,
            **common_style,
        )

    # Axis labels and optional title
    ax.set(xlabel=xlabel, ylabel=ylabel)
    if title:
        ax.set_title(title)

    # Grid, ticks, legend and layout
    apply_grid(
        ax,
        show=show_grid,
        alpha=grid_alpha,
        linewidth=grid_linewidth,
        color=grid_color,
    )
    ax.minorticks_on()
    ax.legend(loc=legend_loc)
    fig.tight_layout(pad=0.3)
    return fig