# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """The Beta distribution class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export __all__ = [ "Beta", "BetaWithSoftplusConcentration", ] _beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" @tf_export("distributions.Beta") class Beta(distribution.Distribution): """Beta distribution. The Beta distribution is defined over the `(0, 1)` interval using parameters `concentration1` (aka "alpha") and `concentration0` (aka "beta"). #### Mathematical Details The probability density function (pdf) is, ```none pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta) ``` where: * `concentration1 = alpha`, * `concentration0 = beta`, * `Z` is the normalization constant, and, * `Gamma` is the [gamma function]( https://en.wikipedia.org/wiki/Gamma_function). The concentration parameters represent mean total counts of a `1` or a `0`, i.e., ```none concentration1 = alpha = mean * total_concentration concentration0 = beta = (1. - mean) * total_concentration ``` where `mean` in `(0, 1)` and `total_concentration` is a positive real number representing a mean `total_count = concentration1 + concentration0`. Distribution parameters are automatically broadcast in all functions; see examples for details. Warning: The samples can be zero due to finite precision. This happens more often when some of the concentrations are very small. Make sure to round the samples to `np.finfo(dtype).tiny` before computing the density. Samples of this distribution are reparameterized (pathwise differentiable). The derivatives are computed using the approach described in the paper [Michael Figurnov, Shakir Mohamed, Andriy Mnih. Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) #### Examples ```python # Create a batch of three Beta distributions. alpha = [1, 2, 3] beta = [1, 2, 3] dist = tf.distributions.Beta(alpha, beta) dist.sample([4, 5]) # Shape [4, 5, 3] # `x` has three batch entries, each with two samples. x = [[.1, .4, .5], [.2, .3, .5]] # Calculate the probability of each pair of samples under the corresponding # distribution in `dist`. dist.prob(x) # Shape [2, 3] ``` ```python # Create batch_shape=[2, 3] via parameter broadcast: alpha = [[1.], [2]] # Shape [2, 1] beta = [3., 4, 5] # Shape [3] dist = tf.distributions.Beta(alpha, beta) # alpha broadcast as: [[1., 1, 1,], # [2, 2, 2]] # beta broadcast as: [[3., 4, 5], # [3, 4, 5]] # batch_Shape [2, 3] dist.sample([4, 5]) # Shape [4, 5, 2, 3] x = [.2, .3, .5] # x will be broadcast as [[.2, .3, .5], # [.2, .3, .5]], # thus matching batch_shape [2, 3]. dist.prob(x) # Shape [2, 3] ``` Compute the gradients of samples w.r.t. the parameters: ```python alpha = tf.constant(1.0) beta = tf.constant(2.0) dist = tf.distributions.Beta(alpha, beta) samples = dist.sample(5) # Shape [5] loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function # Unbiased stochastic gradients of the loss function grads = tf.gradients(loss, [alpha, beta]) ``` """ def __init__(self, concentration1=None, concentration0=None, validate_args=False, allow_nan_stats=True, name="Beta"): """Initialize a batch of Beta distributions. Args: concentration1: Positive floating-point `Tensor` indicating mean number of successes; aka "alpha". Implies `self.dtype` and `self.batch_shape`, i.e., `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`. concentration0: Positive floating-point `Tensor` indicating mean number of failures; aka "beta". Otherwise has same semantics as `concentration1`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ parameters = dict(locals()) with ops.name_scope(name, values=[concentration1, concentration0]) as name: self._concentration1 = self._maybe_assert_valid_concentration( ops.convert_to_tensor(concentration1, name="concentration1"), validate_args) self._concentration0 = self._maybe_assert_valid_concentration( ops.convert_to_tensor(concentration0, name="concentration0"), validate_args) check_ops.assert_same_float_dtype([ self._concentration1, self._concentration0]) self._total_concentration = self._concentration1 + self._concentration0 super(Beta, self).__init__( dtype=self._total_concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration1, self._concentration0, self._total_concentration], name=name) @staticmethod def _param_shapes(sample_shape): return dict(zip( ["concentration1", "concentration0"], [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)) @property def concentration1(self): """Concentration parameter associated with a `1` outcome.""" return self._concentration1 @property def concentration0(self): """Concentration parameter associated with a `0` outcome.""" return self._concentration0 @property def total_concentration(self): """Sum of concentration parameters.""" return self._total_concentration def _batch_shape_tensor(self): return array_ops.shape(self.total_concentration) def _batch_shape(self): return self.total_concentration.get_shape() def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): return tensor_shape.scalar() def _sample_n(self, n, seed=None): expanded_concentration1 = array_ops.ones_like( self.total_concentration, dtype=self.dtype) * self.concentration1 expanded_concentration0 = array_ops.ones_like( self.total_concentration, dtype=self.dtype) * self.concentration0 gamma1_sample = random_ops.random_gamma( shape=[n], alpha=expanded_concentration1, dtype=self.dtype, seed=seed) gamma2_sample = random_ops.random_gamma( shape=[n], alpha=expanded_concentration0, dtype=self.dtype, seed=distribution_util.gen_new_seed(seed, "beta")) beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample) return beta_sample @distribution_util.AppendDocstring(_beta_sample_note) def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() @distribution_util.AppendDocstring(_beta_sample_note) def _prob(self, x): return math_ops.exp(self._log_prob(x)) @distribution_util.AppendDocstring(_beta_sample_note) def _log_cdf(self, x): return math_ops.log(self._cdf(x)) @distribution_util.AppendDocstring(_beta_sample_note) def _cdf(self, x): return math_ops.betainc(self.concentration1, self.concentration0, x) def _log_unnormalized_prob(self, x): x = self._maybe_assert_valid_sample(x) return ((self.concentration1 - 1.) * math_ops.log(x) + (self.concentration0 - 1.) * math_ops.log1p(-x)) def _log_normalization(self): return (math_ops.lgamma(self.concentration1) + math_ops.lgamma(self.concentration0) - math_ops.lgamma(self.total_concentration)) def _entropy(self): return ( self._log_normalization() - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1) - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0) + ((self.total_concentration - 2.) * math_ops.digamma(self.total_concentration))) def _mean(self): return self._concentration1 / self._total_concentration def _variance(self): return self._mean() * (1. - self._mean()) / (1. + self.total_concentration) @distribution_util.AppendDocstring( """Note: The mode is undefined when `concentration1 <= 1` or `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If `self.allow_nan_stats` is `False` an exception is raised when one or more modes are undefined.""") def _mode(self): mode = (self.concentration1 - 1.) / (self.total_concentration - 2.) if self.allow_nan_stats: nan = array_ops.fill( self.batch_shape_tensor(), np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), name="nan") is_defined = math_ops.logical_and(self.concentration1 > 1., self.concentration0 > 1.) return array_ops.where(is_defined, mode, nan) return control_flow_ops.with_dependencies([ check_ops.assert_less( array_ops.ones([], dtype=self.dtype), self.concentration1, message="Mode undefined for concentration1 <= 1."), check_ops.assert_less( array_ops.ones([], dtype=self.dtype), self.concentration0, message="Mode undefined for concentration0 <= 1.") ], mode) def _maybe_assert_valid_concentration(self, concentration, validate_args): """Checks the validity of a concentration parameter.""" if not validate_args: return concentration return control_flow_ops.with_dependencies([ check_ops.assert_positive( concentration, message="Concentration parameter must be positive."), ], concentration) def _maybe_assert_valid_sample(self, x): """Checks the validity of a sample.""" if not self.validate_args: return x return control_flow_ops.with_dependencies([ check_ops.assert_positive(x, message="sample must be positive"), check_ops.assert_less( x, array_ops.ones([], self.dtype), message="sample must be less than `1`."), ], x) class BetaWithSoftplusConcentration(Beta): """Beta with softplus transform of `concentration1` and `concentration0`.""" def __init__(self, concentration1, concentration0, validate_args=False, allow_nan_stats=True, name="BetaWithSoftplusConcentration"): parameters = dict(locals()) with ops.name_scope(name, values=[concentration1, concentration0]) as name: super(BetaWithSoftplusConcentration, self).__init__( concentration1=nn.softplus(concentration1, name="softplus_concentration1"), concentration0=nn.softplus(concentration0, name="softplus_concentration0"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=name) self._parameters = parameters @kullback_leibler.RegisterKL(Beta, Beta) def _kl_beta_beta(d1, d2, name=None): """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta. Args: d1: instance of a Beta distribution object. d2: instance of a Beta distribution object. name: (optional) Name to use for created operations. default is "kl_beta_beta". Returns: Batchwise KL(d1 || d2) """ def delta(fn, is_property=True): fn1 = getattr(d1, fn) fn2 = getattr(d2, fn) return (fn2 - fn1) if is_property else (fn2() - fn1()) with ops.name_scope(name, "kl_beta_beta", values=[ d1.concentration1, d1.concentration0, d1.total_concentration, d2.concentration1, d2.concentration0, d2.total_concentration, ]): return (delta("_log_normalization", is_property=False) - math_ops.digamma(d1.concentration1) * delta("concentration1") - math_ops.digamma(d1.concentration0) * delta("concentration0") + (math_ops.digamma(d1.total_concentration) * delta("total_concentration")))