# Source code for convoys.plotting

import numpy
from matplotlib import pyplot
import convoys.multi

# Public API of this module: ``from convoys.plotting import *`` exports
# only plot_cohorts.
__all__ = ['plot_cohorts']


def _kaplan_meier(ci):
    # Built without the ``mcmc`` option; ``ci`` is accepted only so that
    # every factory can be called the same way.
    return convoys.multi.KaplanMeier()


def _exponential(ci):
    return convoys.multi.Exponential(mcmc=ci)


def _weibull(ci):
    return convoys.multi.Weibull(mcmc=ci)


def _gamma(ci):
    return convoys.multi.Gamma(mcmc=ci)


def _generalized_gamma(ci):
    return convoys.multi.GeneralizedGamma(mcmc=ci)


# Model factories selectable by name, keyed by the strings accepted by
# ``plot_cohorts(model=...)``. Each factory takes ``ci`` and, apart from
# Kaplan-Meier, forwards it to the model as ``mcmc``.
_models = {
    'kaplan-meier': _kaplan_meier,
    'exponential': _exponential,
    'weibull': _weibull,
    'gamma': _gamma,
    'generalized-gamma': _generalized_gamma,
}


def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', ci=None, ax=None,
                 plot_kwargs=None, plot_ci_kwargs=None, groups=None,
                 specific_groups=None,
                 label_fmt='%(group)s (n=%(n).0f, k=%(k).0f)'):
    ''' Helper function to fit data using a model and then plot the cohorts.

    For every group in `specific_groups`, the model's prediction is drawn as
    a percentage over ``[0, t_max]``, optionally with a shaded confidence
    band whose color is reused for the line.

    :param G: array with the group index of each item; items with
        ``G == j`` belong to ``groups[j]``. A plain list is accepted.
    :param B: array with a conversion flag (boolean or 0/1) for each item;
        it is summed within each group to give the `k` shown in the legend
    :param T: array with a time value for each item; when `t_max` is not
        given, the x axis is extended to at least ``max(T)``
    :param t_max: (optional) max value for x axis. Defaults to the larger
        of the axis' current upper x limit and ``max(T)``.
    :param model: (optional, default is kaplan-meier) model to fit. Can be
        an instance of :class:`multi.MultiModel` (used as-is, not re-fitted)
        or a string identifying the model. One of 'kaplan-meier',
        'exponential', 'weibull', 'gamma', or 'generalized-gamma'.
    :param ci: confidence interval, value from 0-1, or None (default) if no
        confidence interval is to be plotted
    :param ax: custom pyplot axis to plot on (defaults to ``pyplot.gca()``)
    :param plot_kwargs: extra arguments to pyplot for the lines
    :param plot_ci_kwargs: extra arguments to pyplot for the confidence
        intervals
    :param groups: list of group labels, where ``groups[j]`` labels group
        index ``j``. Defaults to the sorted distinct values of `G`.
    :param specific_groups: subset of groups to plot (defaults to all)
    :param label_fmt: custom format for the labels to use in the legend;
        %-formatted with the keys `group`, `n` (number of items in the
        group) and `k` (sum of `B` over the group)
    :returns: the fitted (or passed-in) model
    :raises ValueError: if `model` is neither a known model name nor a
        :class:`multi.MultiModel`, or if `specific_groups` contains labels
        that are not in `groups`

    See :meth:`convoys.utils.get_arrays` which is handy for converting
    a Pandas dataframe into arrays `G`, `B`, `T`.
    '''
    # Validate everything before touching the axes or fitting (which may
    # be expensive). isinstance is checked first so model instances are
    # never hashed, and the str check keeps unhashable arguments from
    # surfacing as a TypeError instead of this error.
    is_model_instance = isinstance(model, convoys.multi.MultiModel)
    if not is_model_instance and not (isinstance(model, str) and model in _models):
        raise ValueError('model incorrectly specified')

    # None defaults avoid sharing one mutable dict between calls.
    if plot_kwargs is None:
        plot_kwargs = {}
    if plot_ci_kwargs is None:
        plot_ci_kwargs = {}

    if groups is None:
        # Sorted so the default labelling does not depend on set iteration
        # order: with contiguous indices 0..n-1, groups.index(g) == g.
        groups = sorted(set(G))
    if specific_groups is None:
        specific_groups = groups
    if set(specific_groups).difference(groups):
        raise ValueError('specific_groups not a subset of groups!')

    # Arrays make `G == j` an element-wise mask even if lists were passed.
    G_arr = numpy.asarray(G)
    B_arr = numpy.asarray(B)

    if ax is None:
        ax = pyplot.gca()

    # Set x scale
    if t_max is None:
        _, t_max = ax.get_xlim()
        t_max = max(t_max, max(T))

    if is_model_instance:
        m = model
    else:
        # Fit model; the factory is told whether a CI was requested.
        m = _models[model](ci=bool(ci))
        m.fit(G, B, T)

    # Plot
    t = numpy.linspace(0, t_max, 1000)
    _, y_max = ax.get_ylim()
    ax.set_prop_cycle(None)  # Reset to first color
    for group in specific_groups:
        j = groups.index(group)  # matching index of group
        in_group = G_arr == j
        n = numpy.sum(in_group)
        k = numpy.sum(B_arr[in_group])
        label = label_fmt % dict(group=group, n=n, k=k)
        if ci is not None:
            p_y, p_y_lo, p_y_hi = m.predict_ci(j, t, ci=ci).T
            merged_plot_ci_kwargs = {'alpha': 0.2}
            merged_plot_ci_kwargs.update(plot_ci_kwargs)
            p = ax.fill_between(t, 100. * p_y_lo, 100. * p_y_hi,
                                **merged_plot_ci_kwargs)
            color = p.get_facecolor()[0]  # reuse color for the line
        else:
            p_y = m.predict(j, t).T
            color = None
        merged_plot_kwargs = {'color': color, 'linewidth': 1.5, 'alpha': 0.7}
        merged_plot_kwargs.update(plot_kwargs)
        ax.plot(t, 100. * p_y, label=label, **merged_plot_kwargs)
        # Leave 10% headroom above the highest curve.
        y_max = max(y_max, 110. * max(p_y))

    ax.set_xlim([0, t_max])
    ax.set_ylim([0, y_max])
    ax.set_ylabel('Conversion rate %')
    ax.grid(True)
    return m