from deprecated.sphinx import deprecated
import numpy
from convoys import regression
from convoys import single
# Public API of this module: only the concrete multi-group models are
# exported; the adapter base classes stay internal.
__all__ = [
    'KaplanMeier',
    'Exponential',
    'Weibull',
    'Gamma',
    'GeneralizedGamma',
]
class MultiModel:
    '''Common base class of the multi-group model adapters in this module.

    It currently defines no behaviour of its own.
    '''
    # TODO: declare the shared multi-group interface here.
class RegressionToMulti(MultiModel):
    '''Multi-group adapter around a single regression model.

    Subclasses set ``_base_model_cls``. One instance of it is created at
    construction time and fit on a one-hot encoding of the group indices,
    so all groups share the same underlying model.
    '''

    def __init__(self, *args, **kwargs):
        # Every constructor argument is forwarded unchanged to the wrapped
        # regression model.
        self.base_model = self._base_model_cls(*args, **kwargs)

    def fit(self, G, B, T):
        ''' Fits the model

        :param G: numpy vector of shape :math:`n` holding the non-negative
            integer group index of each observation
        :param B: numpy vector of shape :math:`n`
        :param T: numpy vector of shape :math:`n`
        :raises ValueError: if ``G`` is empty or contains a negative group
            index.
        '''
        # Builtin ``int``/``bool`` are used as dtypes: the ``numpy.int``
        # alias was removed in NumPy 1.24 (as was ``numpy.bool`` until NumPy
        # 2.0 reintroduced it), so the old spelling raised AttributeError.
        G = numpy.array(G, dtype=int)
        n, = G.shape
        if n == 0:
            raise ValueError('G must contain at least one observation')
        if G.min() < 0:
            # A negative index would silently wrap around to the last column
            # of the one-hot matrix and assign the observation to the wrong
            # group.
            raise ValueError('group indices in G must be non-negative')
        self._n_groups = int(G.max()) + 1
        # One-hot encode: row i has a single True in column G[i].
        X = numpy.zeros((n, self._n_groups), dtype=bool)
        X[numpy.arange(n), G] = True
        self.base_model.fit(X, B, T)

    def _get_x(self, group):
        '''Returns the one-hot feature vector for ``group``.

        Its length is ``_n_groups``, which is only set by :meth:`fit`, so
        this must be called after fitting.
        '''
        x = numpy.zeros(self._n_groups)
        x[group] = 1
        return x

    def predict(self, group, t):
        '''Returns the base model's ``predict`` output for ``group`` at ``t``.'''
        return self.base_model.predict(self._get_x(group), t)

    def predict_ci(self, group, t, ci):
        '''Returns the base model's ``predict_ci`` output for ``group`` at
        ``t`` with confidence level ``ci``.'''
        return self.base_model.predict_ci(self._get_x(group), t, ci)

    def rvs(self, group, *args, **kwargs):
        '''Delegates to the base model's ``rvs`` for ``group``; any extra
        arguments are forwarded unchanged.'''
        return self.base_model.rvs(self._get_x(group), *args, **kwargs)

    @deprecated(version='0.2.0',
                reason='Use :meth:`predict` or :meth:`predict_ci` instead.')
    def cdf(self, group, t, ci=None):
        '''Returns the predicted values.

        Deprecated alias: calls :meth:`predict_ci` when ``ci`` is given and
        :meth:`predict` otherwise.
        '''
        if ci is not None:
            return self.predict_ci(group, t, ci)
        else:
            return self.predict(group, t)
class SingleToMulti(MultiModel):
    '''Multi-group adapter around a single-group model.

    Subclasses set ``_base_model_cls``. During :meth:`fit` a separate
    instance of it is created and fit independently on the observations of
    each group.
    '''

    def __init__(self, *args, **kwargs):
        # Keep the constructor arguments rather than capturing them in a
        # lambda attribute: a lambda makes the instance unpicklable and
        # creates a reference cycle through ``self``.
        self._base_model_args = args
        self._base_model_kwargs = kwargs

    def base_model_init(self):
        '''Returns a new, unfitted instance of the wrapped model, built with
        the arguments given to the constructor.'''
        return self._base_model_cls(*self._base_model_args,
                                    **self._base_model_kwargs)

    def fit(self, G, B, T):
        ''' Fits the model

        :param G: numpy vector of shape :math:`n` holding the group of each
            observation (any hashable value)
        :param B: numpy vector of shape :math:`n`
        :param T: numpy vector of shape :math:`n`
        :raises ValueError: if ``G``, ``B`` and ``T`` differ in length.
        '''
        # Materialise the inputs so their lengths can be compared; ``zip``
        # alone would silently drop the trailing observations of the longer
        # inputs.
        G, B, T = list(G), list(B), list(T)
        if not len(G) == len(B) == len(T):
            raise ValueError('G, B and T must have the same length '
                             '(got %d, %d and %d)' % (len(G), len(B), len(T)))
        group2bt = {}
        for g, b, t in zip(G, B, T):
            group2bt.setdefault(g, []).append((b, t))
        self._group2model = {}
        for g, BT in group2bt.items():
            model = self.base_model_init()
            model.fit([b for b, t in BT], [t for b, t in BT])
            self._group2model[g] = model

    def predict(self, group, t):
        '''Returns the prediction of ``group``'s model at ``t``.

        Raises ``KeyError`` if ``group`` was not present in the data passed
        to :meth:`fit`.
        '''
        return self._group2model[group].predict(t)

    def predict_ci(self, group, t, ci):
        '''Returns the ``predict_ci`` output of ``group``'s model at ``t``
        with confidence level ``ci``.

        Raises ``KeyError`` if ``group`` was not present in the data passed
        to :meth:`fit`.
        '''
        return self._group2model[group].predict_ci(t, ci)

    @deprecated(version='0.2.0',
                reason='Use :meth:`predict` or :meth:`predict_ci` instead')
    def cdf(self, group, t, ci=None):
        '''Returns the predicted values.

        Deprecated alias: calls :meth:`predict_ci` when ``ci`` is given and
        :meth:`predict` otherwise.
        '''
        if ci is not None:
            return self.predict_ci(group, t, ci)
        else:
            return self.predict(group, t)
class Exponential(RegressionToMulti):
    ''' Multi-group version of :class:`convoys.regression.Exponential`.

    A single regression model is shared by all groups; each observation's
    group index is one-hot encoded as its feature vector.
    '''
    _base_model_cls = regression.Exponential
class Weibull(RegressionToMulti):
    ''' Multi-group version of :class:`convoys.regression.Weibull`.

    A single regression model is shared by all groups; each observation's
    group index is one-hot encoded as its feature vector.
    '''
    _base_model_cls = regression.Weibull
class Gamma(RegressionToMulti):
    ''' Multi-group version of :class:`convoys.regression.Gamma`.

    A single regression model is shared by all groups; each observation's
    group index is one-hot encoded as its feature vector.
    '''
    _base_model_cls = regression.Gamma
class GeneralizedGamma(RegressionToMulti):
    ''' Multi-group version of :class:`convoys.regression.GeneralizedGamma`.

    A single regression model is shared by all groups; each observation's
    group index is one-hot encoded as its feature vector.
    '''
    _base_model_cls = regression.GeneralizedGamma
class KaplanMeier(SingleToMulti):
    ''' Multi-group version of :class:`convoys.single.KaplanMeier`.

    A separate estimator is fit independently on the observations of each
    group.
    '''
    _base_model_cls = single.KaplanMeier