diff --git a/symbulate/plot.py b/symbulate/plot.py index 4cd3eee..80b5a6d 100644 --- a/symbulate/plot.py +++ b/symbulate/plot.py @@ -1,6 +1,7 @@ import numpy as np import matplotlib.pyplot as plt from scipy.stats import gaussian_kde +from matplotlib.transforms import Affine2D figure = plt.figure @@ -129,6 +130,15 @@ def make_marginal_impulse(count, color, ax_marg, alpha, axis): ax_marg.vlines(key, 0, val, color=color, alpha=alpha) elif axis == 'y': ax_marg.hlines(key, 0, val, color=color, alpha=alpha) + +def make_density(x, ax, color, axis = 'x'): + density = compute_density(x) + xs = np.linspace(x.min(), x.max(), 1000) + if axis == 'x': + ax.plot(xs, density(xs), linewidth=2, color=color) + elif axis == 'y': + ax.plot(xs, density(xs), linewidth=2, color=color, + transform=Affine2D().rotate_deg(270) + ax.transData) def make_density2D(x, y, ax): res = np.vstack([x, y]) diff --git a/symbulate/results.py b/symbulate/results.py index ee53784..919d101 100644 --- a/symbulate/results.py +++ b/symbulate/results.py @@ -12,14 +12,13 @@ from matplotlib.gridspec import GridSpec from matplotlib.ticker import NullFormatter -from matplotlib.transforms import Affine2D from .base import (Arithmetic, Statistical, Comparable, Logical, Filterable, Transformable) from .plot import (configure_axes, get_next_color, is_discrete, count_var, compute_density, add_colorbar, setup_ticks, make_tile, make_violin, - make_marginal_impulse, make_density2D) + make_marginal_impulse, make_density, make_density2D) from .result import (Scalar, Vector, TimeFunction, is_number, is_numeric_vector) from .table import Table @@ -473,9 +472,7 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, if len(type) == 1: plt.ylabel('Relative Frequency') else: - density = compute_density(self.array) - xs = np.linspace(self.array.min(), self.array.max(), 1000) - ax.plot(xs, density(xs), linewidth=2, color=color) + make_density(self.array, ax, color) if len(type) == 1 or (len(type) == 2 and 'rug' in type): plt.ylabel('Density') @@ -498,7 +495,7 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, if 'rug' in type: xs = self.array if discrete: - noise_level = .002 * (self.array.max() - self.array.min()) + noise_level = .002 * (xs.max() - xs.min()) xs = xs + np.random.normal(scale=noise_level, size=n) ax.plot(xs, [0.001] * n, '|', linewidth=5, color='k') if len(type) == 1: @@ -510,10 +507,8 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, x_count = count_var(x) y_count = count_var(y) - x_height = x_count.values() - y_height = y_count.values() - discrete_x = is_discrete(x_height) - discrete_y = is_discrete(y_height) + discrete_x = is_discrete(x_count.values()) + discrete_y = is_discrete(y_count.values()) if type is None: type = ("scatter",) @@ -521,24 +516,18 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, alpha = .5 if bins is None: bins = 10 if 'tile' in type else 30 + + fig = plt.gcf() if 'marginal' in type: - fig = plt.gcf() gs = GridSpec(4, 4) ax = fig.add_subplot(gs[1:4, 0:3]) ax_marg_x = fig.add_subplot(gs[0, 0:3]) ax_marg_y = fig.add_subplot(gs[1:4, 3]) color = get_next_color(ax) if 'density' in type: - densityX = compute_density(x) - densityY = compute_density(y) - x_lines = np.linspace(min(x), max(x), 1000) - y_lines = np.linspace(min(y), max(y), 1000) - ax_marg_x.plot(x_lines, densityX(x_lines), linewidth=2, - color=get_next_color(ax)) - ax_marg_y.plot(y_lines, densityY(y_lines), linewidth=2, - color=get_next_color(ax), - transform=Affine2D().rotate_deg(270) + ax_marg_y.transData) + make_density(x, ax_marg_x, get_next_color(ax)) + make_density(y, ax_marg_y, get_next_color(ax), 'y') else: if discrete_x: make_marginal_impulse(x_count, get_next_color(ax), ax_marg_x, alpha, 'x') @@ -553,7 +542,6 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, plt.setp(ax_marg_x.get_xticklabels(), visible=False) plt.setp(ax_marg_y.get_yticklabels(), visible=False) else: - fig = plt.gcf() ax = plt.gca() color = get_next_color(ax)