Source code for curveball.plots

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# This file is part of curveball.
# https://github.com/yoavram/curveball

# Licensed under the MIT license:
# http://www.opensource.org/licenses/MIT-license
# Copyright (c) 2015, Yoav Ram <yoav@yoavram.com>
from __future__ import division
from builtins import range
from past.utils import old_div
import copy
import numpy as np
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt
try:
	from pandas.plotting import lag_plot # py3.6
except ImportError:
	from pandas.tools.plotting import lag_plot # py3.5
import seaborn as sns
sns.set_style("ticks")

from matplotlib.patches import RegularPolygon
from string import ascii_uppercase


[docs]def plot_wells(df, x='Time', y='OD', plot_func=plt.plot, output_filename=None): """Plot a grid of plots, one for each well in the plate. The facetting is done by the ``Row`` and ``Col`` columns of `df`. The colors are given by the ``Color`` column, the labels of the colors are given by the ``Strain`` column. If ``Strain`` is missing then the coloring is done by the ``Well`` column. Parameters ---------- df : pandas.DataFrame growth curve data, see :py:mod:`curveball.ioutils` for a detailed definition. x : str, optional name of column for x-axis, defaults to ``Time``. y : str, optional name of column for y-axis, defaults to ``OD``. plot_func : func, optional function to use for plotting, defaults to :py:func:`matplotlib.pyplot.plot` output_filename : str, optional filename to save the resulting figure; if not given, figure is not saved. Returns ------- seaborn.FacetGrid figure object. """ if 'Strain' in df: hue = 'Strain' palette = df.Color.unique() if 'Color' in df else sns.color_palette() hue_order = df.Strain.unique() palette[palette == '#ffffff'] = '#000000' else: hue = 'Well' palette = sns.color_palette() hue_order = df.Well height = len(df.Row.unique()) width = len(df.Col.unique()) g = sns.FacetGrid(df, hue=hue, col='Col', row='Row', palette=palette, hue_order=hue_order, sharex=True, sharey=True, height=1, aspect=old_div(width,float(height)), despine=True,margin_titles=True) g.map(plot_func, x, y) g.fig.set_figwidth(width) g.fig.set_figheight(height) plt.locator_params(nbins=4) # 4 ticks is enough g.set_axis_labels('','') # remove facets axis labels g.fig.text(0.5, 0, x, size='x-large') # xlabel g.fig.text(-0.01, 0.5, y, size='x-large', rotation='vertical') # ylabel if output_filename: g.savefig(output_filename, bbox_inches='tight', pad_inches=1) return g
[docs]def plot_strains(data, x='Time', y='OD', plot_func=plt.plot, by=None, agg_func=np.mean, hue='Strain', color=None, output_filename=None, **kwargs): """Aggregate by strain and plot the results on one figure with different color for each strain. The grouping of the data is done by the ``Strain`` and either ``Cycle Nr.`` or ``Time`` columns of `data`; the aggregation is done by the `agg_func`, which defaults to :py:func:`numpy.mean`. The colors are given by the ``Color`` column, the labels of the colors are given by the ``Strain`` column of `data`. Parameters ---------- data : pandas.DataFrame growth curve data, see :py:mod:`curveball.ioutils` for a detailed definition. x : str, optional name of column for x-axis, defaults to ``Time``. y : str, optional name of column for y-axis, defaults to ``OD``. plot_func : func, optional function to use for plotting, defaults to :py:func:`matplotlib.pyplot.plot` by : tuple of str, optional used for grouping the data, defaults to ``('Strain', 'Cycle Nr.')`` or ``('Strain', 'Time')``, whichever is available. plot_func : func, optional function to use for aggregating the data, defaults to :py:func:`numpy.mean`. color : seaborn color palette a seaborn color palette to use if there is no ``Color`` column; if not given, using the default palette. output_filename : str, optional filename to save the resulting figure; if not given, figure is not saved. Returns ------- seaborn.FacetGrid figure object. Raises ------ ValueError raised if `by` isn't set and `data` doesn't contain ``Strain`` and either ``Time`` or ``Cycle Nr.``. """ if 'Color' in data: palette = data.Color.unique() palette[palette == '#ffffff'] = '#000000' else: palette = color or sns.color_palette() if by is None: if 'Cycle Nr.' in data and 'Strain' in data: by = ['Strain', 'Cycle Nr.'] elif 'Time' in data and 'Strain' in data: by = ['Strain', 'Time'] else: raise ValueError("If by is not set then data must have column Strain and either Time or Cycle Nr.") grp = data.groupby(by=by) agg = grp.aggregate(agg_func).reset_index() g = sns.FacetGrid(agg, hue=hue, height=5, aspect=1.5, palette=palette, hue_order=data[hue].unique()) g.map(plot_func, x, y); g.add_legend() if output_filename: g.savefig(output_filename, bbox_inches='tight', pad_inches=1) return g
[docs]def tsplot(data, x='Time', y='OD', ci_level=95, ax=None, color=None, output_filename=None, **kwargs): """Time series plot of the data by strain (if applicable) or well. The grouping of the data is done by the value of `x` and ``Strain``, if such a column exists in `data`; otherwise it is done by `x` and ``Well``. The aggregation is done by :py:func:`seaborn.lineplot` which calculates the mean with a confidence interval. The colors are given by the ``Color`` column, the labels of the colors are given by the ``Strain`` column; if ``Strain`` and ``Color`` don't exist in `data` then the function will use a default palette and color the lines by well. Parameters ---------- data : pandas.DataFrame growth curve data, see :py:mod:`curveball.ioutils` for a detailed definition. x : str, optional name of column for x-axis, defaults to ``Time``. y : str, optional name of column for y-axis, defaults to ``OD``. ci_level : int, optional confidence interval width in precent (0-100), defaults to 95. ax : matplotlib.axes.Axes, optional plot into this axes, if not given create a new figure. color : seaborn color palette a seaborn color palette to use if there is no ``Color`` column; if not given, using the default palette. output_filename : str, optional filename to save the resulting figure; if not given, figure is not saved. Returns ------- matplotlib.axes.Axes axes object """ if 'Strain' in data: condition = 'Strain' else: condition = 'Well' if 'Color' in data: palette = data['Color'].unique() palette[palette == '#ffffff'] = '#000000' else: palette = color or sns.color_palette() g = sns.lineplot(data=data, x=x, hue=condition, y=y, err_style='band', ci=ci_level, palette=list(palette), ax=ax) sns.despine() if output_filename: g.figure.savefig(output_filename, bbox_inches='tight', pad_inches=1) return g
[docs]def plot_plate(data, edge_color='#888888', output_filename=None): """Plot of the plate color mapping. The function will plot the color mapping in `data`: a grid with enough columns and rows for the ``Col`` and ``Row`` columns in `data`, where the color of each grid cell given by the ``Color`` column. Parameters ---------- data : pandas.DataFrame growth curve data, see :py:mod:`curveball.ioutils` for a detailed definition. edge_color : str color hex string for the grid edges. Returns ------- fig : matplotlib.figure.Figure figure object ax : numpy.ndarray array of axis objects. """ plate = data.pivot('Row', 'Col', 'Color').values height, width = plate.shape fig = plt.figure(figsize=((width + 2.0) / 3.0, (height + 2.0) / 3.0)) ax = fig.add_axes((0.05, 0.05, 0.9, 0.9), aspect='equal', frameon=False, xlim=(-0.05, width + 0.05), ylim=(-0.05, height + 0.05)) for axis in (ax.xaxis, ax.yaxis): axis.set_major_formatter(plt.NullFormatter()) axis.set_major_locator(plt.NullLocator()) # Create the grid of squares squares = np.array([[RegularPolygon((i + 0.5, j + 0.5), numVertices=4, radius=0.5 * np.sqrt(2), orientation=old_div(np.pi, 4), ec=edge_color, fc=plate[height-1-j,i]) for j in range(height)] for i in range(width)]) [ax.add_patch(sq) for sq in squares.flat] ax.set_xticks(np.arange(width) + 0.5) ax.set_xticklabels(np.arange(1, 1 + width)) ax.set_yticks(np.arange(height) + 0.5) ax.set_yticklabels(ascii_uppercase[height-1::-1]) ax.xaxis.tick_top() ax.yaxis.tick_left() ax.tick_params(length=0, width=0) if output_filename: fig.savefig(output_filename, bbox_inches='tight', pad_inches=1) return fig, ax
[docs]def plot_params_distribution(param_samples, color='k', cmap="viridis", alpha=None): """Plots a distribution of model parameter samples generated with :py:func:`curveball.models.sample_params`. Parameters ---------- param_samples : pandas.DataFrame data frame of samples; each row is one sample, each column is one parameter. alpha : float transparency of plot markers, defaults to :math:`1/n^{1/4}` where *n* is number of rows in `param_samples`. Returns ------- seaborn.Grid figure object """ nsamples = param_samples.shape[0] g = sns.PairGrid(param_samples) if alpha is None: alpha = 1.0 / np.power(nsamples, 1.0 / 4.0) g.map_upper(plt.scatter, alpha=alpha, color=color) g.map_lower(sns.kdeplot, cmap=cmap, legend=False, shade=True, shade_lowest=False) g.map_diag(plt.hist, facecolor=color) # https://github.com/mwaskom/seaborn/pull/788 return g
def _plot_fitted_histogram(data, rv=scipy.stats.norm, color='k', label=None, alpha=0.5, ax=None): """This is basically `sns.distplot(fit=rv)`. TODO: `low,high = np.percentile(x, 2.5), np.percentile(x, 97.5)` """ if ax is None: fig, ax = plt.subplots(1, 1) else: fig = ax.figure rv_params = rv.fit(data) rv_inst = rv(*rv_params) nbins = min(100, len(data)) n, bins, patches = ax.hist(data, bins=nbins, color=color, alpha=alpha, density=True) ax.plot(bins, rv_inst.pdf(bins), color='k', lw=2) ax.annotate( r'$\mu={:.2g}, \sigma={:.2g}$'.format(rv_inst.mean(), rv_inst.std()), xy=(bins[len(bins)//2], np.max(n)), xycoords="data", horizontalalignment='center', fontsize=plt.rcParams['axes.labelsize'] ) return fig, ax
[docs]def plot_model_residuals(model_fit, rv=scipy.stats.norm, color='k'): """Plot of the residuals of a model fit. The function will plot the residuals - the difference between data and model - for a given model fit. The left panel shows the residuals over time; the right panel shows the histogram of the residuals with a fitted distribution curve. Parameters ---------- model_fit : lmfit.ModelResult the result of a model fitting procedure. rv : scipy.stats.rv_continuous, optional :py:class:`scipy.stats.rv_continuous` random variable whose probability density function (pdf) will be fitted to the histogram. Defaults to the normal distribution (`scipy.stats.norm`). color : str, optional color string for the plot, defaults to `k` for black. Returns ------- fig : matplotlib.figure.Figure figure object ax : numpy.ndarray array of axis objects. """ w, h= plt.rcParams['figure.figsize'] fig,ax = plt.subplots(1, 2, figsize=(w * 2, h)) model_fit.plot_residuals(ax=ax[0], data_kws={'color': color}) # removed, causes bug in lmfit: fit_kws={'color': color}) ax[0].set_xlabel('Time (hr)') ax[0].set_ylabel('Residuals') ax[0].legend().set_visible(False) ax[0].set_title('') _plot_fitted_histogram(model_fit.residual, rv=rv, color=color, ax=ax[1]) ax[1].set(xlabel='Residuals', ylabel='Frequency') fig.tight_layout() sns.despine() return fig, ax
[docs]def plot_residuals(df, time='Time', value='OD', resid_func=lambda x: x - x.mean(), rv=scipy.stats.norm, color='k', ax=None): """Plot of the residuals of in the data. The function will plot the residuals - the difference between data and average at each time point. The left panel shows the residuals over time. The middle panel shows the histogram of the residuals with a fitted distribution (defaults to Gaussian). The right panel shows the regression between the standard deviation at time `t+1` and `t` to identify autocorrelation. Parameters ---------- df : pandas.DataFrame a data frame with columns ``Time`` and ``OD``. time : str, optional name of column over which to group and plot the residuals. Defaults to ``Time``. value : str, optional name of column in `df` of the value on which to compute the residuals. Defaults to ``OD``. resid_func : function, optional function to calculate residuals. Defaults to ``x - x.mean()``. rv : scipy.stats.rv_continuous, optional :py:class:`scipy.stats.rv_continuous` random variable whose probability density function (pdf) will be fitted to the histogram. Defaults the normal distribution (:py:class:`scipy.stats.norm`). color : str, optional color string for the plot, defaults to `k` for black. Returns ------- fig : matplotlib.figure.Figure figure object ax : numpy.ndarray array of axis objects. """ w, h= plt.rcParams['figure.figsize'] fig,ax = plt.subplots(1, 3, figsize=(w * 3, h)) residuals = df.groupby(time)[value].transform(resid_func).values ax[0].plot(df[time], residuals, ls='', marker='o', color=color) ax[0].set(xlabel=time, ylabel='Residuals') _plot_fitted_histogram(residuals, rv=rv, color=color, ax=ax[1]) ax[1].set(xlabel='Residuals', ylabel='Frequency') sigmas = df.groupby(time)[value].std() linreg = scipy.stats.linregress(sigmas.values[:-1], sigmas.values[1:]) eq = r'$\sigma_{{t+1}} = {:.2g} + {:.2g} \sigma_{{t}}$'.format(linreg.intercept, linreg.slope) sigma_range = np.linspace(sigmas.min(), sigmas.max()) ax[2].plot(sigma_range, sigma_range, color='k', ls='--', label=r'$\sigma_{t+1}=\sigma_{t}$') ax[2].plot(sigma_range, linreg.intercept + linreg.slope * sigma_range, color=color, label=eq) lag_plot(sigmas, c='k', ax=ax[2]) ax[2].set(xlabel=r'$\sigma_{t}$', ylabel=r'$\sigma_{t+1}$') ax[2].legend(loc='upper left') fig.tight_layout() sns.despine() return fig, ax
[docs]def plot_sample_fit(model_fit, param_samples, fit_kws=None, data_kws=None, sample_kws=None): """Plot of sampled curve fits. The function will plot the main model fit and the sampled curve fits based on a table of sample parameters. Parameters ---------- model_fit : lmfit.ModelResult the result of a model fitting procedure. param_samples : pandas.DataFrame data frame of samples; each row is one sample, each column is one parameter. fit_kws, data_kws, sample_kws : dict dictionaries of plot directives for the fit, data, and sampled fit curves. Returns ------- fig : matplotlib.figure.Figure figure object ax : numpy.ndarray array of axis objects. """ t = np.linspace(0, model_fit.userkws['t'].max()) def f(params): return model_fit.model.eval(t=t, params=params) nsamples = param_samples.shape[0] _fit_kws = dict(linewidth=5) if fit_kws: _fit_kws.update(fit_kws) _data_kws = dict(marker='.') if data_kws: _data_kws.update(data_kws) _sample_kws = dict(linestyle='--', color='gray', alpha=1/np.sqrt(nsamples)) if sample_kws: _sample_kws.update(sample_kws) ax = model_fit.plot_fit(init_kws={'ls': ''}, fit_kws=_fit_kws, data_kws=_data_kws) for i in range(nsamples): sample = param_samples.iloc[i, :] params = model_fit.params.copy() for k, v in params.items(): if v.vary: params[k].set(value=sample[k]) plt.plot(t, f(params), **_sample_kws) ax.legend().set_visible(False) ax.set(ylabel='OD', xlabel='Time', title='') sns.despine() return ax.figure, ax