-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathforest.py
More file actions
191 lines (172 loc) · 6.28 KB
/
Copy pathforest.py
File metadata and controls
191 lines (172 loc) · 6.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Forest plot for meta-analysis: effect sizes with confidence intervals.
One function:
* ``forest_plot`` — point estimates with CIs and optional pooled summary.
Example
-------
>>> import numpy as np
>>> from academic_plot import forest_plot, savefig
>>> studies = ["Smith 2019", "Lee 2020", "Garcia 2021"]
>>> est = np.array([0.45, 0.62, 0.38])
>>> lo = np.array([0.20, 0.40, 0.15])
>>> hi = np.array([0.70, 0.84, 0.61])
>>> fig = forest_plot(studies, est, lo, hi, summary_estimate=0.48,
... summary_ci=(0.35, 0.61))
>>> savefig(fig, "forest_demo")
"""
from __future__ import annotations
from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np
from .style import COLORS, PALETTE, FIGSIZES, GRID_ALPHA, GRID_LINEWIDTH, GRID_COLOR, Z_ORDER, savefig, apply_grid
from .utils import validate_arrays
def forest_plot(
    labels: Sequence[str],
    estimates: np.ndarray,
    ci_low: np.ndarray,
    ci_high: np.ndarray,
    *,
    xlabel: str = "Effect Size",
    ylabel: str | None = None,
    title: str | None = None,
    reference_line: float | None = 0.0,
    reference_linestyle: str = "--",
    reference_linewidth: float = 0.7,
    reference_color: str = "#888888",
    marker: str = "s",
    marker_size: float = 5,
    marker_edgecolor: str = "white",
    marker_edgewidth: float = 0.3,
    color: str | None = None,
    error_linewidth: float = 0.8,
    capsize: float = 2.5,
    capthick: float = 0.6,
    summary_estimate: float | None = None,
    summary_ci: tuple[float, float] | None = None,
    summary_label: str = "Overall",
    summary_color: str | None = None,
    summary_marker_size: float = 7,
    summary_linewidth: float = 1.8,
    show_grid: bool = True,
    grid_alpha: float = GRID_ALPHA,
    grid_linewidth: float = GRID_LINEWIDTH,
    grid_color: str = GRID_COLOR,
    figsize: tuple[float, float] = FIGSIZES["forest"],
) -> plt.Figure:
    """Forest plot showing point estimates with confidence intervals.

    Each study (row) is drawn as a marker (square by default) with
    horizontal CI whiskers. Studies are listed top-to-bottom in the order
    given. An optional pooled summary is shown as a diamond with a thick
    CI line below the last study.

    Parameters
    ----------
    labels : sequence of str
        Study / group names (displayed on the y-axis). Must contain one
        entry per estimate.
    estimates : np.ndarray
        Point estimates (one per study).
    ci_low : np.ndarray
        Lower bounds of confidence intervals.
    ci_high : np.ndarray
        Upper bounds of confidence intervals.
    xlabel : str
        x-axis label (the effect-size axis).
    ylabel : str or None
        y-axis label. Usually omitted for forest plots.
    title : str or None
        Subplot title.
    reference_line : float or None
        x position of the vertical reference line (e.g. 0 for no effect
        on a difference scale, 1 for the odds-ratio null). Set to
        ``None`` to hide.
    reference_linestyle : str
        Style of the reference line.
    reference_linewidth : float
        Width of the reference line.
    reference_color : str
        Colour of the reference line.
    marker : str
        Shape of study markers (``"s"`` = square, ``"o"`` = circle, ...).
    marker_size : float
        Size of study markers.
    marker_edgecolor : str
        Border colour of study markers.
    marker_edgewidth : float
        Border width of study markers.
    color : str or None
        Colour of study markers and CI lines. Falls back to ``PALETTE[0]``.
    error_linewidth : float
        Thickness of CI whisker lines.
    capsize : float
        Length of CI whisker caps.
    capthick : float
        Thickness of CI whisker caps.
    summary_estimate : float or None
        Pooled estimate value. Displayed as a diamond if provided.
    summary_ci : tuple[float, float] or None
        ``(low, high)`` of the pooled CI. Required if *summary_estimate*
        is given; ignored when *summary_estimate* is ``None``.
    summary_label : str
        Label for the summary row.
    summary_color : str or None
        Colour of the summary diamond. Falls back to ``COLORS["red"]``.
    summary_marker_size : float
        Size of the diamond marker.
    summary_linewidth : float
        Thickness of the summary CI line.
    show_grid : bool
        Show vertical grid lines.
    grid_alpha, grid_linewidth, grid_color
        Grid styling.
    figsize : tuple[float, float]
        Figure size in inches.

    Returns
    -------
    plt.Figure

    Raises
    ------
    ValueError
        If the number of *labels* differs from the number of
        *estimates*, or if *summary_estimate* is given without
        *summary_ci*.
    """
    validate_arrays(estimates, ci_low, ci_high, names=["estimates", "ci_low", "ci_high"])

    # Coerce after validation so the whisker arithmetic below is
    # element-wise even when plain sequences are passed in.
    estimates = np.asanyarray(estimates, dtype=float)
    ci_low = np.asanyarray(ci_low, dtype=float)
    ci_high = np.asanyarray(ci_high, dtype=float)

    n = len(labels)
    if n != estimates.size:
        # Fail here with a readable message rather than deep inside
        # matplotlib with a shape-mismatch error.
        raise ValueError(
            f"labels has {n} entries but estimates has {estimates.size}"
        )
    if summary_estimate is not None and summary_ci is None:
        # Previously the summary row was silently dropped in this case.
        raise ValueError("summary_ci is required when summary_estimate is given")

    c = color or PALETTE[0]
    sc = summary_color or COLORS["red"]

    # y positions: studies ordered top-to-bottom (reversed)
    y_pos = np.arange(n)[::-1]

    fig, ax = plt.subplots(figsize=figsize)

    # Vertical reference line (e.g. null effect)
    if reference_line is not None:
        ax.axvline(
            reference_line, color=reference_color,
            linewidth=reference_linewidth, linestyle=reference_linestyle,
            zorder=Z_ORDER["background"],
        )

    # Individual studies: markers + CI whiskers. errorbar expects the
    # whisker lengths relative to the point estimate, not absolute bounds.
    ax.errorbar(
        estimates, y_pos,
        xerr=[estimates - ci_low, ci_high - estimates],
        fmt=marker, markersize=marker_size, markerfacecolor=c,
        markeredgecolor=marker_edgecolor, markeredgewidth=marker_edgewidth,
        color=c, ecolor=c, elinewidth=error_linewidth, capsize=capsize,
        capthick=capthick, zorder=Z_ORDER["data"],
    )

    all_labels = list(labels)
    all_y = list(y_pos)

    # Summary estimate: diamond marker + thick CI line
    if summary_estimate is not None:
        sy = -0.8  # place below the last study (which sits at y = 0)
        ax.plot(
            summary_estimate, sy, marker="D", markersize=summary_marker_size,
            markerfacecolor=sc, markeredgecolor="white",
            markeredgewidth=0.4, zorder=Z_ORDER["summary"],
        )
        ax.hlines(
            sy, summary_ci[0], summary_ci[1], colors=sc,
            linewidth=summary_linewidth, zorder=Z_ORDER["data"],
        )
        all_labels.append(summary_label)
        all_y.append(sy)

    ax.set_yticks(all_y)
    ax.set_yticklabels(all_labels, fontsize=7.5)
    ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    apply_grid(ax, show=show_grid, alpha=grid_alpha, linewidth=grid_linewidth, color=grid_color, axis="x")
    fig.tight_layout(pad=0.3)
    return fig