diff --git a/ultraplot/legend.py b/ultraplot/legend.py index 9d11ffb9e..c6c66ee22 100644 --- a/ultraplot/legend.py +++ b/ultraplot/legend.py @@ -1,10 +1,98 @@ +from matplotlib import lines as mlines from matplotlib import legend as mlegend +from matplotlib import legend_handler as mhandler +from matplotlib import patches as mpatches try: from typing import override except ImportError: from typing_extensions import override +__all__ = ["Legend", "LegendEntry"] + + +def _wedge_legend_patch( + legend, + orig_handle, + xdescent, + ydescent, + width, + height, + fontsize, +): + """ + Draw wedge-shaped legend keys for pie wedge handles. + """ + center = (-xdescent + width * 0.5, -ydescent + height * 0.5) + radius = 0.5 * min(width, height) + theta1 = float(getattr(orig_handle, "theta1", 0.0)) + theta2 = float(getattr(orig_handle, "theta2", 300.0)) + if theta2 == theta1: + theta2 = theta1 + 300.0 + return mpatches.Wedge(center, radius, theta1=theta1, theta2=theta2) + + +class LegendEntry(mlines.Line2D): + """ + Convenience artist for custom legend entries. + + This is a lightweight wrapper around `matplotlib.lines.Line2D` that + initializes with empty data so it can be passed directly to + `Axes.legend()` or `Figure.legend()` handles. + """ + + def __init__( + self, + label=None, + *, + color=None, + line=True, + marker=None, + linestyle="-", + linewidth=2, + markersize=6, + markerfacecolor=None, + markeredgecolor=None, + markeredgewidth=None, + alpha=None, + **kwargs, + ): + marker = "o" if marker is None and not line else marker + linestyle = "none" if not line else linestyle + if markerfacecolor is None and color is not None: + markerfacecolor = color + if markeredgecolor is None and color is not None: + markeredgecolor = color + super().__init__( + [], + [], + label=label, + color=color, + marker=marker, + linestyle=linestyle, + linewidth=linewidth, + markersize=markersize, + markerfacecolor=markerfacecolor, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + alpha=alpha, + **kwargs, + ) + + @classmethod + def line(cls, label=None, **kwargs): + """ + Build a line-style legend entry. + """ + return cls(label=label, line=True, **kwargs) + + @classmethod + def marker(cls, label=None, marker="o", **kwargs): + """ + Build a marker-style legend entry. + """ + return cls(label=label, line=False, marker=marker, **kwargs) + class Legend(mlegend.Legend): # Soft wrapper of matplotlib legend's class. @@ -15,6 +103,18 @@ class Legend(mlegend.Legend): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @classmethod + def get_default_handler_map(cls): + """ + Extend matplotlib defaults with a wedge handler for pie legends. + """ + handler_map = dict(super().get_default_handler_map()) + handler_map.setdefault( + mpatches.Wedge, + mhandler.HandlerPatch(patch_func=_wedge_legend_patch), + ) + return handler_map + @override def set_loc(self, loc=None): # Sync location setting with the move diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 8071485e8..872adc46a 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1,6 +1,8 @@ import numpy as np import pandas as pd import pytest +from matplotlib import legend_handler as mhandler +from matplotlib import patches as mpatches import ultraplot as uplt from ultraplot.axes import Axes as UAxes @@ -260,6 +262,59 @@ def test_external_mode_mixing_context_manager(): uplt.close(fig) +def test_legend_entry_helpers(): + h1 = uplt.LegendEntry.line("Line", color="red8", linewidth=3) + h2 = uplt.LegendEntry.marker("Marker", color="blue8", marker="s", markersize=8) + + assert h1.get_linestyle() != "none" + assert h1.get_label() == "Line" + assert h2.get_linestyle() == "None" + assert h2.get_marker() == "s" + assert h2.get_label() == "Marker" + + +def test_legend_entry_with_axes_legend(): + fig, ax = uplt.subplots() + handles = [ + uplt.LegendEntry.line("Trend", color="green7", linewidth=2.5), + uplt.LegendEntry.marker("Samples", color="orange7", marker="o", markersize=7), + ] + leg = ax.legend(handles=handles, loc="best") + + labels = [text.get_text() for text in leg.get_texts()] + assert labels == ["Trend", "Samples"] + lines = leg.get_lines() + assert len(lines) == 2 + assert lines[0].get_linewidth() > 0 + assert lines[1].get_marker() == "o" + uplt.close(fig) + + +def test_pie_legend_uses_wedge_handles(): + fig, ax = uplt.subplots() + wedges, _ = ax.pie([30, 70], labels=["a", "b"]) + leg = ax.legend(wedges, ["a", "b"], loc="best") + handles = leg.legend_handles + assert len(handles) == 2 + assert all(isinstance(handle, mpatches.Wedge) for handle in handles) + uplt.close(fig) + + +def test_pie_legend_handler_map_override(): + fig, ax = uplt.subplots() + wedges, _ = ax.pie([30, 70], labels=["a", "b"]) + leg = ax.legend( + wedges, + ["a", "b"], + loc="best", + handler_map={mpatches.Wedge: mhandler.HandlerPatch()}, + ) + handles = leg.legend_handles + assert len(handles) == 2 + assert all(isinstance(handle, mpatches.Rectangle) for handle in handles) + uplt.close(fig) + + def test_external_mode_toggle_enables_auto(): """ Toggling external mode back off should resume on-the-fly guide creation.