diff --git a/foundry/glm/distributions.py b/foundry/glm/distributions.py
index e4ae855..7ad2dbd 100644
--- a/foundry/glm/distributions.py
+++ b/foundry/glm/distributions.py
@@ -2,8 +2,11 @@
 import torch
 from torch import distributions
-from torch.distributions import constraints, Exponential, Weibull
-from torch.distributions.utils import broadcast_all
+from torch.distributions import constraints, Exponential, Weibull, Poisson, Uniform
+from torch.distributions.utils import (
+    broadcast_all,
+    lazy_property
+)

 from foundry.util import log1mexp

@@ -13,12 +16,21 @@ def __init__(self,
                  loc: torch.Tensor,
                  dispersion: torch.Tensor,
                  validate_args: Optional[bool] = None):
-        loc, dispersion = broadcast_all(loc, dispersion)
+        self.loc, self.dispersion = broadcast_all(loc, dispersion)
         super().__init__(
-            total_count=dispersion,
-            logits=loc.log() - dispersion.log(),
+            total_count=self.dispersion,
+            logits=self.loc.log() - self.dispersion.log(),
             validate_args=validate_args
         )
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(NegativeBinomial, _instance)
+        batch_shape = torch.Size(batch_shape)
+        loc = self.loc.expand(batch_shape)
+        dispersion = self.dispersion.expand(batch_shape)
+        NegativeBinomial.__init__(new, loc=loc, dispersion=dispersion, validate_args=False)
+        new._validate_args = self._validate_args
+        return new


 class _MultinomialStrict(constraints.Constraint):
@@ -149,3 +161,313 @@ def variance(self):
         if (self.ceiling < 1).any():
             raise NotImplementedError("Variance not implemented when ceiling < 1")
         return super().variance
+
+
+def _broadcast_shape(*shapes, **kwargs):
+    """
+    Helper borrowed from pyro source code under Apache License 2.0.
+
+    Similar to ``np.broadcast()`` but for shapes.
+    Equivalent to ``np.broadcast(*map(np.empty, shapes)).shape``.
+
+    :param tuple shapes: shapes of tensors.
+    :param bool strict: whether to use extend-but-not-resize broadcasting.
+    :returns: broadcasted shape
+    :rtype: tuple
+    :raises: ValueError
+    """
+    strict = kwargs.pop("strict", False)
+    reversed_shape = []
+    for shape in shapes:
+        for i, size in enumerate(reversed(shape)):
+            if i >= len(reversed_shape):
+                reversed_shape.append(size)
+            elif reversed_shape[i] == 1 and not strict:
+                reversed_shape[i] = size
+            elif reversed_shape[i] != size and (size != 1 or strict):
+                raise ValueError(
+                    "shape mismatch: objects cannot be broadcast to a single shape: {}".format(
+                        " vs ".join(map(str, shapes))
+                    )
+                )
+    return tuple(reversed(reversed_shape))
+
+
+class ZeroInflatedDistribution(distributions.Distribution):
+    """
+    Generic zero-inflated distribution, adapted/modified from the pyro package
+    source code (under Apache License 2.0).
+
+    Within foundry, this is intended as a base class for
+    :class:`ZeroInflatedPoisson` and :class:`ZeroInflatedNegativeBinomial`.
+
+    :param torch.Tensor prob_zero: probability of extra zeros given via a Bernoulli distribution.
+    :param TorchDistribution base_dist: the base distribution.
+ """ + + arg_constraints = { + "prob_zero": constraints.unit_interval, + } + + def __init__(self, prob_zero, base_dist, validate_args=None): + + batch_shape = _broadcast_shape(base_dist.batch_shape, prob_zero.shape) + self.prob_zero = prob_zero.expand(batch_shape) + + if base_dist.event_shape: + raise ValueError( + "ZeroInflatedDistribution expected empty " + "base_dist.event_shape but got {}".format(base_dist.event_shape) + ) + + self.base_dist = base_dist.expand(batch_shape) + event_shape = torch.Size() + + super().__init__(batch_shape, event_shape, validate_args) + + @constraints.dependent_property + def support(self): + return self.base_dist.support + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + + prob_zero, value = broadcast_all(self.prob_zero, value) + log_prob = (-prob_zero).log1p() + self.base_dist.log_prob(value) + log_prob = torch.where(value == 0, (prob_zero + log_prob.exp()).log(), log_prob) + return log_prob + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + mask = torch.bernoulli(self.prob_zero.expand(shape)).bool() + samples = self.base_dist.expand(shape).sample() + samples = torch.where(mask, samples.new_zeros(()), samples) + return samples + + + @lazy_property + def mean(self): + return (1 - self.prob_zero) * self.base_dist.mean + + @lazy_property + def variance(self): + return (1 - self.prob_zero) * ( + self.base_dist.mean**2 + self.base_dist.variance + ) - (self.mean) ** 2 + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(type(self), _instance) + batch_shape = torch.Size(batch_shape) + prob_zero = self.gate.expand(batch_shape) if "prob_zero" in self.__dict__ else None + base_dist = self.base_dist.expand(batch_shape) + ZeroInflatedDistribution.__init__( + new, prob_zero=prob_zero, base_dist=base_dist, validate_args=False + ) + new._validate_args = self._validate_args + return new + + + +class ZeroInflatedPoisson(ZeroInflatedDistribution): + """ + A Zero Inflated Poisson distribution. + + :param torch.Tensor prob_zero: probability of extra zeros. + :param torch.Tensor rate: rate of poisson distribution. + """ + + arg_constraints = { + "prob_zero": constraints.unit_interval, + "rate": constraints.positive, + } + support = constraints.nonnegative_integer + + def __init__(self, prob_zero, rate, validate_args=None): + + prob_zero, = broadcast_all(prob_zero) + + base_dist = Poisson(rate=rate, validate_args=False) + base_dist._validate_args = validate_args + + super().__init__( + prob_zero=prob_zero, base_dist=base_dist, validate_args=validate_args + ) + + @property + def rate(self): + return self.base_dist.rate + +class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution): + """ + A Zero Inflated Negative Binomial distribution, based on pyro implementation + with foundry parameterization of negative binomial. + + :param torch.Tensor prob_zero: probability of extra zeros. + :param torch.Tensor loc: mean of negative binomial distribution. 
+    :param torch.Tensor dispersion: overdispersion of the negative binomial distribution.
+    """
+
+    arg_constraints = {
+        "prob_zero": constraints.unit_interval,
+        "loc": constraints.positive,
+        "dispersion": constraints.positive,
+    }
+    support = constraints.nonnegative_integer
+
+    def __init__(self, prob_zero, loc, dispersion, validate_args=None):
+
+        prob_zero, = broadcast_all(prob_zero)
+
+        base_dist = NegativeBinomial(loc=loc, dispersion=dispersion, validate_args=False)
+        base_dist._validate_args = validate_args
+
+        super().__init__(
+            prob_zero=prob_zero, base_dist=base_dist, validate_args=validate_args
+        )
+
+    @property
+    def loc(self):
+        return self.base_dist.loc
+
+    @property
+    def dispersion(self):
+        return self.base_dist.dispersion
+
+
+class HurdleDistribution(distributions.Distribution):
+    """
+    Hurdle base distribution, adapted/modified from the pyro package source code for
+    zero-inflated distributions (under Apache License 2.0) and formulae from Stan.
+
+    Unlike a zero-inflated distribution, all zeros come from the hurdle itself:
+    ``prob_zero`` is the total probability of a zero outcome, and the base
+    distribution is zero-truncated for the positive outcomes.
+
+    :param torch.Tensor prob_zero: probability of zero, given via a Bernoulli distribution.
+    :param TorchDistribution base_dist: the base distribution.
+    """
+
+    arg_constraints = {
+        "prob_zero": constraints.unit_interval,
+    }
+
+    def __init__(self, prob_zero, base_dist, validate_args=None):
+
+        batch_shape = _broadcast_shape(base_dist.batch_shape, prob_zero.shape)
+        self.prob_zero = prob_zero.expand(batch_shape)
+
+        if base_dist.event_shape:
+            raise ValueError(
+                "HurdleDistribution expected empty "
+                "base_dist.event_shape but got {}".format(base_dist.event_shape)
+            )
+
+        self.base_dist = base_dist.expand(batch_shape)
+        event_shape = torch.Size()
+
+        super().__init__(batch_shape, event_shape, validate_args)
+
+    @constraints.dependent_property
+    def support(self):
+        return self.base_dist.support
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+
+        prob_zero, value = broadcast_all(self.prob_zero, value)
+        zeros = torch.zeros_like(value)
+        # P(0) = prob_zero; for x > 0 the base distribution is zero-truncated
+        # and rescaled: P(x) = (1 - prob_zero) * base(x) / (1 - base(0)).
+        log_prob = (
+            (-prob_zero).log1p()
+            + self.base_dist.log_prob(value)
+            - (1 - torch.exp(self.base_dist.log_prob(zeros))).log()
+        )
+        log_prob = torch.where(value == 0, prob_zero.log(), log_prob)
+        return log_prob
+
+    @lazy_property
+    def mean(self):
+        # P(0) under the base distribution; log_prob needs a tensor argument.
+        prob_base_zero = torch.exp(self.base_dist.log_prob(torch.zeros(self.batch_shape)))
+        truncated_mean = self.base_dist.mean / (1 - prob_base_zero)
+        return (1 - self.prob_zero) * truncated_mean
+
+    @lazy_property
+    def variance(self):
+        prob_base_zero = torch.exp(self.base_dist.log_prob(torch.zeros(self.batch_shape)))
+
+        truncated_mean_sq = (self.base_dist.mean / (1 - prob_base_zero)) ** 2
+
+        truncated_variance = self.base_dist.variance / (
+            1 - prob_base_zero
+        ) - prob_base_zero * truncated_mean_sq
+
+        return (1 - self.prob_zero) * truncated_variance + self.prob_zero * (
+            1 - self.prob_zero
+        ) * truncated_mean_sq
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(type(self), _instance)
+        batch_shape = torch.Size(batch_shape)
+        prob_zero = self.prob_zero.expand(batch_shape) if "prob_zero" in self.__dict__ else None
+        base_dist = self.base_dist.expand(batch_shape)
+        HurdleDistribution.__init__(
+            new, prob_zero=prob_zero, base_dist=base_dist, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+
+class HurdlePoisson(HurdleDistribution):
+    """
+    A Hurdle Poisson distribution.
+
+    :param torch.Tensor prob_zero: probability of zeros.
+    :param torch.Tensor rate: rate of the Poisson distribution.
+ """ + + arg_constraints = { + "prob_zero": constraints.unit_interval, + "rate": constraints.positive, + } + support = constraints.nonnegative_integer + + def __init__(self, prob_zero, rate, validate_args=None): + + prob_zero, = broadcast_all(prob_zero) + + base_dist = Poisson(rate=rate, validate_args=False) + base_dist._validate_args = validate_args + + super().__init__( + prob_zero=prob_zero, base_dist=base_dist, validate_args=validate_args + ) + + @property + def rate(self): + return self.base_dist.rate + +class HurdleNegativeBinomial(HurdleDistribution): + """ + A Hurdle Negative Binomial distribution. + + :param torch.Tensor prob_zero: probability of extra zeros. + :param torch.Tensor loc: mean of negative binomial distribution. + :param torch.Tensor dispersion: overdispersion of negative binomial distribution + """ + + arg_constraints = { + "prob_zero": constraints.unit_interval, + "loc": constraints.positive, + "dispersion": constraints.positive, + } + support = constraints.nonnegative_integer + + def __init__(self, prob_zero, loc, dispersion, validate_args=None): + + prob_zero, = broadcast_all(prob_zero) + + base_dist = NegativeBinomial(loc=loc, dispersion=dispersion, validate_args=False) + base_dist._validate_args = validate_args + + super().__init__( + prob_zero=prob_zero, base_dist=base_dist, validate_args=validate_args + ) + + @property + def loc(self): + return self.base_dist.loc + + @property + def dispersion(self): + return self.base_dist.dispersion + diff --git a/foundry/glm/glm.py b/foundry/glm/glm.py index 625db09..adec9a7 100644 --- a/foundry/glm/glm.py +++ b/foundry/glm/glm.py @@ -23,7 +23,11 @@ Multinomial, Exponential, Weibull, - CeilingWeibull + CeilingWeibull, + ZeroInflatedPoisson, + ZeroInflatedNegativeBinomial, + HurdlePoisson, + HurdleNegativeBinomial ) from foundry.glm.family import Family, SurvivalFamily, FamilyArgs from foundry.glm.util import NoWeightModule, Stopping, SigmoidTransformForClassification, SoftmaxKp1 @@ -112,6 +116,36 @@ 'concentration': transforms.ExpTransform(), 'ceiling': transforms.SigmoidTransform() } + ), + 'zero_inflated_poisson': FamilyArgs( + ZeroInflatedPoisson, + { + 'prob_zero': transforms.SigmoidTransform(), + 'rate': transforms.ExpTransform() + } + ), + 'zero_inflated_negative_binomial': FamilyArgs( + ZeroInflatedNegativeBinomial, + { + 'prob_zero': transforms.SigmoidTransform(), + 'loc': transforms.ExpTransform(), + 'dispersion': transforms.ExpTransform() + } + ), + 'hurdle_poisson': FamilyArgs( + HurdlePoisson, + { + 'prob_zero': transforms.SigmoidTransform(), + 'rate': transforms.ExpTransform() + } + ), + 'hurdle_negative_binomial': FamilyArgs( + HurdleNegativeBinomial, + { + 'prob_zero': transforms.SigmoidTransform(), + 'loc': transforms.ExpTransform(), + 'dispersion': transforms.ExpTransform() + } ) } family_names['gaussian'] = family_names['normal']