Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from maxplotlib.canvas.canvas import Canvas
from maxplotlib.canvas.canvas import Canvas, SubplotSpacing

__all__ = ["Canvas"]
__all__ = ["Canvas", "SubplotSpacing"]
49 changes: 45 additions & 4 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from typing import Dict
from dataclasses import dataclass
from typing import Mapping

import matplotlib.patches as patches
import matplotlib.pyplot as plt
Expand All @@ -19,6 +20,17 @@
from maxplotlib.utils.options import Backends


@dataclass(frozen=True)
class SubplotSpacing:
"""Typed spacing configuration for subplot grids."""

wspace: float = 0.08
hspace: float = 0.1

def to_gridspec_kw(self) -> dict[str, float]:
return {"wspace": self.wspace, "hspace": self.hspace}


def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None):
"""
Plot all nodes and paths on the provided axis using Matplotlib.
Expand Down Expand Up @@ -167,7 +179,8 @@ def __init__(
dpi: int = 300,
width: str = "5cm",
ratio: str = "golden", # TODO Add literal
gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1},
subplot_spacing: SubplotSpacing | None = None,
gridspec_kw: Mapping[str, float] | None = None,
):
"""
Initialize the Canvas class for multiple subplots.
Expand All @@ -183,7 +196,10 @@ def __init__(
dpi (int): DPI for the figure. Default is 300.
width (str): Width of the figure. Default is "17cm".
ratio (str): Aspect ratio. Default is "golden".
gridspec_kw (dict): Gridspec keyword arguments. Default is {"wspace": 0.08, "hspace": 0.1}.
subplot_spacing (SubplotSpacing): Typed subplot spacing.
Default is SubplotSpacing(wspace=0.08, hspace=0.1).
gridspec_kw (Mapping[str, float]): Optional matplotlib gridspec kwargs.
Kept for compatibility with existing code.
"""

self._nrows = nrows
Expand All @@ -196,7 +212,14 @@ def __init__(
self._dpi = dpi
self._width = width
self._ratio = ratio
self._gridspec_kw = gridspec_kw
if subplot_spacing is not None and gridspec_kw is not None:
raise ValueError("Pass either subplot_spacing or gridspec_kw, not both.")
if subplot_spacing is None and gridspec_kw is None:
subplot_spacing = SubplotSpacing()
if subplot_spacing is not None:
self._gridspec_kw = subplot_spacing.to_gridspec_kw()
else:
self._gridspec_kw = dict(gridspec_kw)
self._plotted = False
self._matplotlib_fig = None
self._matplotlib_axes = None
Expand All @@ -221,6 +244,8 @@ def subplots(
nrows: int = 1,
ncols: int = 1,
squeeze: bool = True,
wspace: float | None = None,
hspace: float | None = None,
**canvas_kwargs,
):
"""
Expand All @@ -231,6 +256,8 @@ def subplots(
nrows, ncols (int): Grid dimensions.
squeeze (bool): If True, return a single subplot instead of a 1-element
list when the grid is 1×1 or when one dimension is 1.
wspace, hspace (float): Convenience subplot spacing arguments.
These map to matplotlib gridspec spacing values.
**canvas_kwargs: Forwarded to the Canvas constructor.

Returns:
Expand All @@ -243,6 +270,19 @@ def subplots(
>>> canvas, (ax1, ax2) = Canvas.subplots(ncols=2)
>>> canvas, axes = Canvas.subplots(nrows=2, ncols=2) # axes[row][col]
"""
spacing_given = wspace is not None or hspace is not None
if spacing_given and (
"subplot_spacing" in canvas_kwargs or "gridspec_kw" in canvas_kwargs
):
raise ValueError(
"Use either wspace/hspace or subplot_spacing/gridspec_kw, not both."
)
if spacing_given:
canvas_kwargs["subplot_spacing"] = SubplotSpacing(
wspace=0.08 if wspace is None else wspace,
hspace=0.1 if hspace is None else hspace,
)

canvas = cls(nrows=nrows, ncols=ncols, **canvas_kwargs)
axes = [
[canvas.add_subplot(row=r, col=c) for c in range(ncols)]
Expand Down Expand Up @@ -829,6 +869,7 @@ def plot_matplotlib(
figsize=(fig_width, fig_height),
squeeze=False,
dpi=self.dpi,
gridspec_kw=self._gridspec_kw,
)

for (row, col), subplot in self._subplot_dict.items():
Expand Down
193 changes: 193 additions & 0 deletions src/maxplotlib/tests/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,198 @@ def test_canvas_plot_tikzfigure_vertical_not_supported():
assert "nrows > 1" in str(exc_info.value)


def test_canvas_matplotlib_gridspec_kw_affects_row_spacing():
"""Test that hspace changes the vertical spacing between rows."""
import matplotlib.pyplot as plt
import numpy as np

from maxplotlib import Canvas

x = np.linspace(0, 1, 5)

tight_canvas, tight_axes = Canvas.subplots(
nrows=2,
ncols=1,
width="10cm",
ratio=0.7,
gridspec_kw={"hspace": 0.02, "wspace": 0.08},
)
for ax in tight_axes:
ax.plot(x, x)
tight_fig, tight_matplotlib_axes = tight_canvas.plot()
tight_gap = (
tight_matplotlib_axes[0, 0].get_position().y0
- tight_matplotlib_axes[1, 0].get_position().y1
)

loose_canvas, loose_axes = Canvas.subplots(
nrows=2,
ncols=1,
width="10cm",
ratio=0.7,
gridspec_kw={"hspace": 0.5, "wspace": 0.08},
)
for ax in loose_axes:
ax.plot(x, x)
loose_fig, loose_matplotlib_axes = loose_canvas.plot()
loose_gap = (
loose_matplotlib_axes[0, 0].get_position().y0
- loose_matplotlib_axes[1, 0].get_position().y1
)

assert loose_gap > tight_gap
plt.close(tight_fig)
plt.close(loose_fig)


def test_canvas_matplotlib_gridspec_kw_affects_2x2_line_spacing():
"""Test that wspace/hspace change spacing for 2×2 line subplot grids."""
import matplotlib.pyplot as plt
import numpy as np

from maxplotlib import Canvas

x = np.linspace(0, 1, 20)

tight_canvas, tight_axes = Canvas.subplots(
nrows=2,
ncols=2,
width="12cm",
ratio=0.7,
hspace=0.03,
wspace=0.03,
)
idx = 0
for row_axes in tight_axes:
for ax in row_axes:
ax.plot(x, (idx + 1) * x)
idx += 1
tight_fig, tight_matplotlib_axes = tight_canvas.plot(backend="matplotlib")
tight_hgap = (
tight_matplotlib_axes[0, 1].get_position().x0
- tight_matplotlib_axes[0, 0].get_position().x1
)
tight_vgap = (
tight_matplotlib_axes[0, 0].get_position().y0
- tight_matplotlib_axes[1, 0].get_position().y1
)

loose_canvas, loose_axes = Canvas.subplots(
nrows=2,
ncols=2,
width="12cm",
ratio=0.7,
hspace=0.45,
wspace=0.45,
)
idx = 0
for row_axes in loose_axes:
for ax in row_axes:
ax.plot(x, (idx + 1) * x)
idx += 1
loose_fig, loose_matplotlib_axes = loose_canvas.plot(backend="matplotlib")
loose_hgap = (
loose_matplotlib_axes[0, 1].get_position().x0
- loose_matplotlib_axes[0, 0].get_position().x1
)
loose_vgap = (
loose_matplotlib_axes[0, 0].get_position().y0
- loose_matplotlib_axes[1, 0].get_position().y1
)

assert loose_hgap > tight_hgap
assert loose_vgap > tight_vgap
plt.close(tight_fig)
plt.close(loose_fig)


def test_canvas_matplotlib_gridspec_kw_affects_2x2_imshow_spacing():
"""Test spacing control also works for 2×2 color (imshow) subplot grids."""
import matplotlib.pyplot as plt
import numpy as np

from maxplotlib import Canvas

data = np.arange(100).reshape(10, 10)

tight_canvas, tight_axes = Canvas.subplots(
nrows=2,
ncols=2,
width="12cm",
ratio=0.8,
hspace=0.03,
wspace=0.03,
)
idx = 0
for row_axes in tight_axes:
for ax in row_axes:
ax.add_imshow(data + idx, cmap="viridis")
ax.set_title(f"Heatmap {idx + 1}")
idx += 1
tight_fig, tight_matplotlib_axes = tight_canvas.plot(backend="matplotlib")
tight_hgap = (
tight_matplotlib_axes[0, 1].get_position().x0
- tight_matplotlib_axes[0, 0].get_position().x1
)
tight_vgap = (
tight_matplotlib_axes[0, 0].get_position().y0
- tight_matplotlib_axes[1, 0].get_position().y1
)

loose_canvas, loose_axes = Canvas.subplots(
nrows=2,
ncols=2,
width="12cm",
ratio=0.8,
hspace=0.45,
wspace=0.45,
)
idx = 0
for row_axes in loose_axes:
for ax in row_axes:
ax.add_imshow(data + idx, cmap="viridis")
ax.set_title(f"Heatmap {idx + 1}")
idx += 1
loose_fig, loose_matplotlib_axes = loose_canvas.plot(backend="matplotlib")
loose_hgap = (
loose_matplotlib_axes[0, 1].get_position().x0
- loose_matplotlib_axes[0, 0].get_position().x1
)
loose_vgap = (
loose_matplotlib_axes[0, 0].get_position().y0
- loose_matplotlib_axes[1, 0].get_position().y1
)

assert loose_hgap > tight_hgap
assert loose_vgap > tight_vgap
plt.close(tight_fig)
plt.close(loose_fig)


def test_canvas_spacing_and_gridspec_kw_are_mutually_exclusive():
import pytest

from maxplotlib import Canvas, SubplotSpacing

with pytest.raises(ValueError):
Canvas(
subplot_spacing=SubplotSpacing(wspace=0.2, hspace=0.2),
gridspec_kw={"wspace": 0.3, "hspace": 0.3},
)


def test_canvas_subplots_spacing_args_and_explicit_spacing_are_mutually_exclusive():
import pytest

from maxplotlib import Canvas, SubplotSpacing

with pytest.raises(ValueError):
Canvas.subplots(
wspace=0.2,
hspace=0.2,
subplot_spacing=SubplotSpacing(wspace=0.3, hspace=0.3),
)


if __name__ == "__main__":
test()
Loading
Loading