-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathheatmap.py
More file actions
148 lines (127 loc) · 4.52 KB
/
Copy pathheatmap.py
File metadata and controls
148 lines (127 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Annotated heatmap for correlation matrices, confusion matrices, etc.
One function:
* ``heatmap`` — 2-D colour-coded matrix with optional numeric annotations.
Example
-------
>>> import numpy as np
>>> from academic_plot import heatmap, savefig
>>> C = np.corrcoef(np.random.randn(200, 4).T)
>>> fig = heatmap(C, xlabels=["A","B","C","D"], ylabels=["A","B","C","D"])
>>> savefig(fig, "heatmap_demo")
"""
from __future__ import annotations
from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np
from .style import savefig
from .utils import validate_matrix
# WCAG 2.x relative luminance at which black and white text give the same
# contrast ratio: (1.0 + 0.05) / (L + 0.05) == (L + 0.05) / (0.0 + 0.05).
_LUMINANCE_THRESHOLD = 0.179
_LIGHT_TEXT = "white"
_DARK_TEXT = "#333"


def _relative_luminance(rgba: np.ndarray) -> np.ndarray:
    """Return the WCAG relative luminance of sRGB colours with channels in [0, 1].

    Only the first three entries of the last axis (R, G, B) are used; an
    alpha channel, if present, is ignored.
    """
    srgb = np.asarray(rgba, dtype=float)[..., :3]
    # Undo the sRGB transfer curve before weighting the channels.
    linear = np.where(
        srgb <= 0.04045, srgb / 12.92, ((srgb + 0.055) / 1.055) ** 2.4
    )
    return linear @ np.array([0.2126, 0.7152, 0.0722])


def _annotate_cells(
    ax: plt.Axes, im, data: np.ndarray, fmt: str, fontsize: float
) -> None:
    """Write ``format(value, fmt)`` at the centre of every cell of *data*.

    *im* is the image returned by ``ax.imshow(data, ...)``. The text colour is
    chosen per cell from the luminance of the colour *im* actually renders
    there (via its own norm and colormap), so it respects ``vmin``/``vmax``
    and works for colormaps that are dark at the low end or at both ends.
    """
    luminance = _relative_luminance(im.to_rgba(data))
    # NaN cells are drawn in the colormap's "bad" colour (transparent by
    # default), so their RGBA says nothing about the visible background;
    # keep the dark text the original code used for them.
    is_nan = np.isnan(np.asarray(data, dtype=float))
    use_light = (luminance < _LUMINANCE_THRESHOLD) & ~is_nan
    rows, cols = data.shape
    for i in range(rows):
        for j in range(cols):
            ax.text(
                j, i, format(data[i, j], fmt), ha="center", va="center",
                fontsize=fontsize,
                color=_LIGHT_TEXT if use_light[i, j] else _DARK_TEXT,
            )


def heatmap(
    data: np.ndarray,
    *,
    xlabels: Sequence[str] | None = None,
    ylabels: Sequence[str] | None = None,
    title: str | None = None,
    cmap: str = "Blues",
    annotate: bool = True,
    fmt: str = ".2f",
    annot_fontsize: float = 6,
    vmin: float | None = None,
    vmax: float | None = None,
    colorbar_label: str | None = None,
    colorbar_fontsize: float = 7,
    colorbar_tick_labelsize: float = 6.5,
    xtick_fontsize: float = 7,
    ytick_fontsize: float = 7,
    xtick_rotation: float = 45,
    figsize: tuple[float, float] | None = None,
) -> plt.Figure:
    """Annotated heatmap (e.g. correlation matrix, confusion matrix).

    Parameters
    ----------
    data : np.ndarray
        2-D array of shape ``(rows, cols)``; passed to ``validate_matrix``
        before use.
    xlabels : sequence of str or None
        Column labels (one per column). Defaults to integer indices.
    ylabels : sequence of str or None
        Row labels (one per row). Defaults to integer indices.
    title : str or None
        Subplot title.
    cmap : str
        Matplotlib colormap name (e.g. ``"Blues"``, ``"RdBu_r"``,
        ``"viridis"``, ``"coolwarm"``).
    annotate : bool
        Print numeric values inside each cell. Text is white on dark cells
        and dark grey on light cells, judged from the luminance of the colour
        actually drawn, so it stays legible for any colormap and any
        ``vmin``/``vmax``.
    fmt : str
        Format string for cell annotations (e.g. ``".2f"``, ``".0f"``).
    annot_fontsize : float
        Font size of cell annotations.
    vmin, vmax : float or None
        Colour scale limits. ``None`` uses data min/max.
    colorbar_label : str or None
        Label on the colour bar.
    colorbar_fontsize : float
        Font size for the colour-bar label.
    colorbar_tick_labelsize : float
        Font size for colour-bar tick labels.
    xtick_fontsize, ytick_fontsize : float
        Font size for axis tick labels.
    xtick_rotation : float
        Rotation of column labels (degrees).
    figsize : tuple[float, float] or None
        Figure size. Auto-calculated from matrix dimensions if ``None``.

    Returns
    -------
    plt.Figure
        The newly created figure holding the heatmap and its colour bar.

    Raises
    ------
    ValueError
        If ``xlabels`` or ``ylabels`` is given and its length does not match
        the number of columns or rows of *data*. ``validate_matrix`` may
        raise its own errors for invalid input.
    """
    validate_matrix(data)
    rows, cols = data.shape
    # Fail early with a clear message instead of misaligned tick labels.
    if xlabels is not None and len(xlabels) != cols:
        raise ValueError(
            f"xlabels has {len(xlabels)} entries but data has {cols} columns"
        )
    if ylabels is not None and len(ylabels) != rows:
        raise ValueError(
            f"ylabels has {len(ylabels)} entries but data has {rows} rows"
        )

    # Auto-size: scale with the matrix dimensions
    if figsize is None:
        figsize = (max(3.0, cols * 0.55 + 0.6),
                   max(2.4, rows * 0.45 + 0.6))
    fig, ax = plt.subplots(figsize=figsize)

    # Render the matrix as an image
    im = ax.imshow(
        data, cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax,
        interpolation="nearest",
    )

    if annotate:
        _annotate_cells(ax, im, data, fmt, annot_fontsize)

    # Axis ticks and labels
    ax.set_xticks(np.arange(cols))
    ax.set_yticks(np.arange(rows))
    ax.set_xticklabels(
        xlabels if xlabels is not None else list(range(cols)),
        fontsize=xtick_fontsize,
    )
    ax.set_yticklabels(
        ylabels if ylabels is not None else list(range(rows)),
        fontsize=ytick_fontsize,
    )

    # Move x-ticks to the top (common for correlation / confusion matrices)
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    plt.setp(ax.get_xticklabels(), rotation=xtick_rotation,
             ha="left", rotation_mode="anchor")

    # Remove spines for a cleaner look
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Colour bar
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=colorbar_tick_labelsize)
    if colorbar_label:
        cbar.set_label(colorbar_label, fontsize=colorbar_fontsize)

    if title:
        ax.set_title(title, pad=10, fontsize=10)

    fig.tight_layout(pad=0.4)
    return fig