diff --git a/src/maxplotlib/__init__.py b/src/maxplotlib/__init__.py index 5f86d34..7b715b2 100644 --- a/src/maxplotlib/__init__.py +++ b/src/maxplotlib/__init__.py @@ -1,3 +1,3 @@ -from maxplotlib.canvas.canvas import Canvas +from maxplotlib.canvas.canvas import Canvas, SubplotSpacing -__all__ = ["Canvas"] +__all__ = ["Canvas", "SubplotSpacing"] diff --git a/src/maxplotlib/canvas/canvas.py b/src/maxplotlib/canvas/canvas.py index ba1027c..0cc7f7f 100644 --- a/src/maxplotlib/canvas/canvas.py +++ b/src/maxplotlib/canvas/canvas.py @@ -1,6 +1,7 @@ import os import re -from typing import Dict +from dataclasses import dataclass +from typing import Mapping import matplotlib.patches as patches import matplotlib.pyplot as plt @@ -19,6 +20,17 @@ from maxplotlib.utils.options import Backends +@dataclass(frozen=True) +class SubplotSpacing: + """Typed spacing configuration for subplot grids.""" + + wspace: float = 0.08 + hspace: float = 0.1 + + def to_gridspec_kw(self) -> dict[str, float]: + return {"wspace": self.wspace, "hspace": self.hspace} + + def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None): """ Plot all nodes and paths on the provided axis using Matplotlib. @@ -167,7 +179,8 @@ def __init__( dpi: int = 300, width: str = "5cm", ratio: str = "golden", # TODO Add literal - gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1}, + subplot_spacing: SubplotSpacing | None = None, + gridspec_kw: Mapping[str, float] | None = None, ): """ Initialize the Canvas class for multiple subplots. @@ -183,7 +196,10 @@ def __init__( dpi (int): DPI for the figure. Default is 300. width (str): Width of the figure. Default is "17cm". ratio (str): Aspect ratio. Default is "golden". - gridspec_kw (dict): Gridspec keyword arguments. Default is {"wspace": 0.08, "hspace": 0.1}. + subplot_spacing (SubplotSpacing): Typed subplot spacing. + Default is SubplotSpacing(wspace=0.08, hspace=0.1). + gridspec_kw (Mapping[str, float]): Optional matplotlib gridspec kwargs. + Kept for compatibility with existing code. """ self._nrows = nrows @@ -196,7 +212,14 @@ def __init__( self._dpi = dpi self._width = width self._ratio = ratio - self._gridspec_kw = gridspec_kw + if subplot_spacing is not None and gridspec_kw is not None: + raise ValueError("Pass either subplot_spacing or gridspec_kw, not both.") + if subplot_spacing is None and gridspec_kw is None: + subplot_spacing = SubplotSpacing() + if subplot_spacing is not None: + self._gridspec_kw = subplot_spacing.to_gridspec_kw() + else: + self._gridspec_kw = dict(gridspec_kw) self._plotted = False self._matplotlib_fig = None self._matplotlib_axes = None @@ -221,6 +244,8 @@ def subplots( nrows: int = 1, ncols: int = 1, squeeze: bool = True, + wspace: float | None = None, + hspace: float | None = None, **canvas_kwargs, ): """ @@ -231,6 +256,8 @@ def subplots( nrows, ncols (int): Grid dimensions. squeeze (bool): If True, return a single subplot instead of a 1-element list when the grid is 1×1 or when one dimension is 1. + wspace, hspace (float): Convenience subplot spacing arguments. + These map to matplotlib gridspec spacing values. **canvas_kwargs: Forwarded to the Canvas constructor. Returns: @@ -243,6 +270,19 @@ def subplots( >>> canvas, (ax1, ax2) = Canvas.subplots(ncols=2) >>> canvas, axes = Canvas.subplots(nrows=2, ncols=2) # axes[row][col] """ + spacing_given = wspace is not None or hspace is not None + if spacing_given and ( + "subplot_spacing" in canvas_kwargs or "gridspec_kw" in canvas_kwargs + ): + raise ValueError( + "Use either wspace/hspace or subplot_spacing/gridspec_kw, not both." + ) + if spacing_given: + canvas_kwargs["subplot_spacing"] = SubplotSpacing( + wspace=0.08 if wspace is None else wspace, + hspace=0.1 if hspace is None else hspace, + ) + canvas = cls(nrows=nrows, ncols=ncols, **canvas_kwargs) axes = [ [canvas.add_subplot(row=r, col=c) for c in range(ncols)] @@ -829,6 +869,7 @@ def plot_matplotlib( figsize=(fig_width, fig_height), squeeze=False, dpi=self.dpi, + gridspec_kw=self._gridspec_kw, ) for (row, col), subplot in self._subplot_dict.items(): diff --git a/src/maxplotlib/tests/test_canvas.py b/src/maxplotlib/tests/test_canvas.py index b0cedde..aa48efb 100644 --- a/src/maxplotlib/tests/test_canvas.py +++ b/src/maxplotlib/tests/test_canvas.py @@ -77,5 +77,198 @@ def test_canvas_plot_tikzfigure_vertical_not_supported(): assert "nrows > 1" in str(exc_info.value) +def test_canvas_matplotlib_gridspec_kw_affects_row_spacing(): + """Test that hspace changes the vertical spacing between rows.""" + import matplotlib.pyplot as plt + import numpy as np + + from maxplotlib import Canvas + + x = np.linspace(0, 1, 5) + + tight_canvas, tight_axes = Canvas.subplots( + nrows=2, + ncols=1, + width="10cm", + ratio=0.7, + gridspec_kw={"hspace": 0.02, "wspace": 0.08}, + ) + for ax in tight_axes: + ax.plot(x, x) + tight_fig, tight_matplotlib_axes = tight_canvas.plot() + tight_gap = ( + tight_matplotlib_axes[0, 0].get_position().y0 + - tight_matplotlib_axes[1, 0].get_position().y1 + ) + + loose_canvas, loose_axes = Canvas.subplots( + nrows=2, + ncols=1, + width="10cm", + ratio=0.7, + gridspec_kw={"hspace": 0.5, "wspace": 0.08}, + ) + for ax in loose_axes: + ax.plot(x, x) + loose_fig, loose_matplotlib_axes = loose_canvas.plot() + loose_gap = ( + loose_matplotlib_axes[0, 0].get_position().y0 + - loose_matplotlib_axes[1, 0].get_position().y1 + ) + + assert loose_gap > tight_gap + plt.close(tight_fig) + plt.close(loose_fig) + + +def test_canvas_matplotlib_gridspec_kw_affects_2x2_line_spacing(): + """Test that wspace/hspace change spacing for 2×2 line subplot grids.""" + import matplotlib.pyplot as plt + import numpy as np + + from maxplotlib import Canvas + + x = np.linspace(0, 1, 20) + + tight_canvas, tight_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.7, + hspace=0.03, + wspace=0.03, + ) + idx = 0 + for row_axes in tight_axes: + for ax in row_axes: + ax.plot(x, (idx + 1) * x) + idx += 1 + tight_fig, tight_matplotlib_axes = tight_canvas.plot(backend="matplotlib") + tight_hgap = ( + tight_matplotlib_axes[0, 1].get_position().x0 + - tight_matplotlib_axes[0, 0].get_position().x1 + ) + tight_vgap = ( + tight_matplotlib_axes[0, 0].get_position().y0 + - tight_matplotlib_axes[1, 0].get_position().y1 + ) + + loose_canvas, loose_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.7, + hspace=0.45, + wspace=0.45, + ) + idx = 0 + for row_axes in loose_axes: + for ax in row_axes: + ax.plot(x, (idx + 1) * x) + idx += 1 + loose_fig, loose_matplotlib_axes = loose_canvas.plot(backend="matplotlib") + loose_hgap = ( + loose_matplotlib_axes[0, 1].get_position().x0 + - loose_matplotlib_axes[0, 0].get_position().x1 + ) + loose_vgap = ( + loose_matplotlib_axes[0, 0].get_position().y0 + - loose_matplotlib_axes[1, 0].get_position().y1 + ) + + assert loose_hgap > tight_hgap + assert loose_vgap > tight_vgap + plt.close(tight_fig) + plt.close(loose_fig) + + +def test_canvas_matplotlib_gridspec_kw_affects_2x2_imshow_spacing(): + """Test spacing control also works for 2×2 color (imshow) subplot grids.""" + import matplotlib.pyplot as plt + import numpy as np + + from maxplotlib import Canvas + + data = np.arange(100).reshape(10, 10) + + tight_canvas, tight_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.8, + hspace=0.03, + wspace=0.03, + ) + idx = 0 + for row_axes in tight_axes: + for ax in row_axes: + ax.add_imshow(data + idx, cmap="viridis") + ax.set_title(f"Heatmap {idx + 1}") + idx += 1 + tight_fig, tight_matplotlib_axes = tight_canvas.plot(backend="matplotlib") + tight_hgap = ( + tight_matplotlib_axes[0, 1].get_position().x0 + - tight_matplotlib_axes[0, 0].get_position().x1 + ) + tight_vgap = ( + tight_matplotlib_axes[0, 0].get_position().y0 + - tight_matplotlib_axes[1, 0].get_position().y1 + ) + + loose_canvas, loose_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.8, + hspace=0.45, + wspace=0.45, + ) + idx = 0 + for row_axes in loose_axes: + for ax in row_axes: + ax.add_imshow(data + idx, cmap="viridis") + ax.set_title(f"Heatmap {idx + 1}") + idx += 1 + loose_fig, loose_matplotlib_axes = loose_canvas.plot(backend="matplotlib") + loose_hgap = ( + loose_matplotlib_axes[0, 1].get_position().x0 + - loose_matplotlib_axes[0, 0].get_position().x1 + ) + loose_vgap = ( + loose_matplotlib_axes[0, 0].get_position().y0 + - loose_matplotlib_axes[1, 0].get_position().y1 + ) + + assert loose_hgap > tight_hgap + assert loose_vgap > tight_vgap + plt.close(tight_fig) + plt.close(loose_fig) + + +def test_canvas_spacing_and_gridspec_kw_are_mutually_exclusive(): + import pytest + + from maxplotlib import Canvas, SubplotSpacing + + with pytest.raises(ValueError): + Canvas( + subplot_spacing=SubplotSpacing(wspace=0.2, hspace=0.2), + gridspec_kw={"wspace": 0.3, "hspace": 0.3}, + ) + + +def test_canvas_subplots_spacing_args_and_explicit_spacing_are_mutually_exclusive(): + import pytest + + from maxplotlib import Canvas, SubplotSpacing + + with pytest.raises(ValueError): + Canvas.subplots( + wspace=0.2, + hspace=0.2, + subplot_spacing=SubplotSpacing(wspace=0.3, hspace=0.3), + ) + + if __name__ == "__main__": test() diff --git a/src/maxplotlib/tests/test_plot.py b/src/maxplotlib/tests/test_plot.py index e69de29..7667ed3 100644 --- a/src/maxplotlib/tests/test_plot.py +++ b/src/maxplotlib/tests/test_plot.py @@ -0,0 +1,104 @@ +import matplotlib.pyplot as plt +import numpy as np + +from maxplotlib import Canvas + + +def test_python_example_nxm_line_subplots_spacing_changes(): + """Python example: 2x2 line subplots honor wspace/hspace settings.""" + x = np.linspace(0, 2 * np.pi, 200) + + tight_canvas, tight_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.7, + wspace=0.05, + hspace=0.05, + ) + for i, row in enumerate(tight_axes): + for j, ax in enumerate(row): + ax.plot(x, np.sin((i + 1) * (j + 1) * x)) + tight_fig, tight_m_axes = tight_canvas.plot(backend="matplotlib") + tight_hgap = ( + tight_m_axes[0, 1].get_position().x0 - tight_m_axes[0, 0].get_position().x1 + ) + tight_vgap = ( + tight_m_axes[0, 0].get_position().y0 - tight_m_axes[1, 0].get_position().y1 + ) + + loose_canvas, loose_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.7, + wspace=0.45, + hspace=0.45, + ) + for i, row in enumerate(loose_axes): + for j, ax in enumerate(row): + ax.plot(x, np.sin((i + 1) * (j + 1) * x)) + loose_fig, loose_m_axes = loose_canvas.plot(backend="matplotlib") + loose_hgap = ( + loose_m_axes[0, 1].get_position().x0 - loose_m_axes[0, 0].get_position().x1 + ) + loose_vgap = ( + loose_m_axes[0, 0].get_position().y0 - loose_m_axes[1, 0].get_position().y1 + ) + + assert loose_hgap > tight_hgap + assert loose_vgap > tight_vgap + plt.close(tight_fig) + plt.close(loose_fig) + + +def test_python_example_nxm_color_subplots_spacing_changes(): + """Python example: 2x2 color subplots (imshow) honor wspace/hspace.""" + base = np.arange(100).reshape(10, 10) + + tight_canvas, tight_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.8, + wspace=0.05, + hspace=0.05, + ) + idx = 0 + for row in tight_axes: + for ax in row: + ax.add_imshow(base + idx, cmap="viridis") + idx += 1 + tight_fig, tight_m_axes = tight_canvas.plot(backend="matplotlib") + tight_hgap = ( + tight_m_axes[0, 1].get_position().x0 - tight_m_axes[0, 0].get_position().x1 + ) + tight_vgap = ( + tight_m_axes[0, 0].get_position().y0 - tight_m_axes[1, 0].get_position().y1 + ) + + loose_canvas, loose_axes = Canvas.subplots( + nrows=2, + ncols=2, + width="12cm", + ratio=0.8, + wspace=0.45, + hspace=0.45, + ) + idx = 0 + for row in loose_axes: + for ax in row: + ax.add_imshow(base + idx, cmap="viridis") + idx += 1 + loose_fig, loose_m_axes = loose_canvas.plot(backend="matplotlib") + loose_hgap = ( + loose_m_axes[0, 1].get_position().x0 - loose_m_axes[0, 0].get_position().x1 + ) + loose_vgap = ( + loose_m_axes[0, 0].get_position().y0 - loose_m_axes[1, 0].get_position().y1 + ) + + assert loose_hgap > tight_hgap + assert loose_vgap > tight_vgap + plt.close(tight_fig) + plt.close(loose_fig) diff --git a/tutorials/tutorial_10_matplotlib_nxm_spacing.ipynb b/tutorials/tutorial_10_matplotlib_nxm_spacing.ipynb new file mode 100644 index 0000000..e85ba12 --- /dev/null +++ b/tutorials/tutorial_10_matplotlib_nxm_spacing.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Tutorial 10 - Matplotlib NxM Subplots and Spacing\n", + "\n", + "This tutorial shows how to build **NxM subplot grids** with the matplotlib backend and control the distance between plots using `wspace=...` and `hspace=...`.\n", + "\n", + "It includes both:\n", + "- line plots\n", + "- color plots (`add_imshow`)\n", + "\n", + "and prints measured subplot gaps so you can verify spacing changes numerically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from maxplotlib import Canvas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "def measure_gaps(axes):\n", + " \"\"\"Measure one representative horizontal and vertical subplot gap.\"\"\"\n", + " horizontal_gap = axes[0, 1].get_position().x0 - axes[0, 0].get_position().x1\n", + " vertical_gap = axes[0, 0].get_position().y0 - axes[1, 0].get_position().y1\n", + " return horizontal_gap, vertical_gap" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## 1 · NxM line plots with spacing control" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "x = np.linspace(0, 2 * np.pi, 200)\n", + "\n", + "tight_canvas, tight_axes = Canvas.subplots(\n", + " nrows=2,\n", + " ncols=3,\n", + " width=\"14cm\",\n", + " ratio=0.65,\n", + " wspace=0.05,\n", + " hspace=0.08,\n", + ")\n", + "for i, row in enumerate(tight_axes):\n", + " for j, ax in enumerate(row):\n", + " ax.plot(x, np.sin((i + 1) * (j + 1) * x), label=f\"sin({(i + 1) * (j + 1)}x)\")\n", + " ax.set_title(f\"line {i},{j}\")\n", + "\n", + "tight_fig, tight_m_axes = tight_canvas.plot(backend=\"matplotlib\")\n", + "tight_fig.suptitle(\"Line plots - tight spacing\")\n", + "tight_h, tight_v = measure_gaps(tight_m_axes)\n", + "print(f\"tight line gaps: h={tight_h:.4f}, v={tight_v:.4f}\")\n", + "\n", + "loose_canvas, loose_axes = Canvas.subplots(\n", + " nrows=2,\n", + " ncols=3,\n", + " width=\"14cm\",\n", + " ratio=0.65,\n", + " wspace=0.45,\n", + " hspace=0.45,\n", + ")\n", + "for i, row in enumerate(loose_axes):\n", + " for j, ax in enumerate(row):\n", + " ax.plot(x, np.sin((i + 1) * (j + 1) * x), label=f\"sin({(i + 1) * (j + 1)}x)\")\n", + " ax.set_title(f\"line {i},{j}\")\n", + "\n", + "loose_fig, loose_m_axes = loose_canvas.plot(backend=\"matplotlib\")\n", + "loose_fig.suptitle(\"Line plots - loose spacing\")\n", + "loose_h, loose_v = measure_gaps(loose_m_axes)\n", + "print(f\"loose line gaps: h={loose_h:.4f}, v={loose_v:.4f}\")\n", + "\n", + "assert loose_h > tight_h\n", + "assert loose_v > tight_v" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## 2 · NxM color plots (`imshow`) with spacing control" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "base = np.arange(100).reshape(10, 10)\n", + "\n", + "tight_canvas, tight_axes = Canvas.subplots(\n", + " nrows=2,\n", + " ncols=3,\n", + " width=\"14cm\",\n", + " ratio=0.75,\n", + " wspace=0.05,\n", + " hspace=0.08,\n", + ")\n", + "idx = 0\n", + "for row in tight_axes:\n", + " for ax in row:\n", + " ax.add_imshow(base + idx, cmap=\"viridis\")\n", + " ax.set_title(f\"heatmap {idx}\")\n", + " idx += 1\n", + "\n", + "tight_fig, tight_m_axes = tight_canvas.plot(backend=\"matplotlib\")\n", + "tight_fig.suptitle(\"Color plots - tight spacing\")\n", + "tight_h, tight_v = measure_gaps(tight_m_axes)\n", + "print(f\"tight color gaps: h={tight_h:.4f}, v={tight_v:.4f}\")\n", + "\n", + "loose_canvas, loose_axes = Canvas.subplots(\n", + " nrows=2,\n", + " ncols=3,\n", + " width=\"14cm\",\n", + " ratio=0.75,\n", + " wspace=0.45,\n", + " hspace=0.45,\n", + ")\n", + "idx = 0\n", + "for row in loose_axes:\n", + " for ax in row:\n", + " ax.add_imshow(base + idx, cmap=\"viridis\")\n", + " ax.set_title(f\"heatmap {idx}\")\n", + " idx += 1\n", + "\n", + "loose_fig, loose_m_axes = loose_canvas.plot(backend=\"matplotlib\")\n", + "loose_fig.suptitle(\"Color plots - loose spacing\")\n", + "loose_h, loose_v = measure_gaps(loose_m_axes)\n", + "print(f\"loose color gaps: h={loose_h:.4f}, v={loose_v:.4f}\")\n", + "\n", + "assert loose_h > tight_h\n", + "assert loose_v > tight_v\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}