Source code for etudes.plotting

"""Plotting module."""

import numpy as np
import matplotlib.pyplot as plt


def plot_image_grid(ax, images, shape, nrows=20, ncols=None, cmap=None):
    """Tile the first ``nrows * ncols`` images into one mosaic and draw it.

    Parameters
    ----------
    ax : matplotlib Axes-like
        Object whose ``imshow`` method draws the mosaic.
    images : array-like
        Indexable batch of images; each element is reshaped to ``shape``,
        so flattened images are accepted.
    shape : tuple of int
        Per-image shape, e.g. ``(h, w)``, ``(h, w, c)`` or ``(1, h, w)``.
        Singleton dimensions are dropped, so single-channel images are
        drawn as 2-D arrays.
    nrows : int, default 20
        Number of tile rows in the mosaic.
    ncols : int, optional
        Number of tile columns; defaults to ``nrows``.
    cmap : optional
        Colormap forwarded to ``ax.imshow``.

    Returns
    -------
    object
        Whatever ``ax.imshow`` returns.

    Raises
    ------
    ValueError
        If ``shape`` has fewer than two non-singleton dimensions.

    Notes
    -----
    Images fill the grid column by column: image ``k`` is placed at tile
    row ``k % nrows`` and tile column ``k // nrows``. If fewer than
    ``nrows * ncols`` images are supplied, the remaining tiles are blank
    (zeros).
    """
    if ncols is None:
        ncols = nrows
    count = nrows * ncols

    # Drop singleton dims of the per-image shape only. Squeezing the whole
    # grid would also collapse the grid axes when nrows or ncols is 1,
    # which scrambles the mosaic.
    tile_shape = tuple(dim for dim in shape if dim != 1)
    if len(tile_shape) < 2:
        raise ValueError(
            f"shape {tuple(shape)!r} must have at least two non-singleton "
            "dimensions"
        )

    tiles = np.asarray(images[:count])
    tiles = tiles.reshape(tiles.shape[0], *tile_shape)
    missing = count - tiles.shape[0]
    if missing > 0:
        # Pad with blank tiles instead of failing inside reshape.
        blank = np.zeros((missing,) + tile_shape, dtype=tiles.dtype)
        tiles = np.concatenate([tiles, blank])

    height, width, *rest = tile_shape
    # grid[c, r] is image c * nrows + r (column-by-column fill).
    grid = tiles.reshape(ncols, nrows, *tile_shape)
    # Reorder to (row, y, col, x, ...) so a single reshape yields the mosaic.
    axes = (1, 2, 0, 3) + tuple(range(4, grid.ndim))
    mosaic = grid.transpose(axes).reshape(nrows * height, ncols * width, *rest)
    return ax.imshow(mosaic, cmap=cmap)
def fill_between_stddev(X_pred, mean_pred, stddev_pred, n=1, ax=None,
                        *args, **kwargs):
    """Shade the band ``mean_pred +/- n * stddev_pred`` over ``X_pred``.

    Parameters
    ----------
    X_pred : array-like
        x-coordinates passed to ``ax.fill_between``.
    mean_pred, stddev_pred : array-like
        Mean and standard deviation at each x. Both are converted with
        ``np.asarray`` so plain sequences are accepted.
    n : float, default 1
        Number of standard deviations above and below the mean.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    *args, **kwargs
        Forwarded to ``ax.fill_between`` after ``(x, y1, y2)``.

    Returns
    -------
    object
        Whatever ``ax.fill_between`` returns.
    """
    if ax is None:
        ax = plt.gca()
    mean = np.asarray(mean_pred)
    spread = n * np.asarray(stddev_pred)
    # Extra positional args used to be accepted but silently dropped;
    # forward them along with the keyword arguments.
    return ax.fill_between(X_pred, mean - spread, mean + spread,
                           *args, **kwargs)