# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Dirichlet",
]


_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
`self.batch_shape() + self.event_shape()`."""


@tf_export("distributions.Dirichlet")
class Dirichlet(distribution.Distribution):
  """Dirichlet distribution.

  The Dirichlet distribution is defined over the
  [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
  length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
  Beta distribution when `k = 2`.

  #### Mathematical Details

  The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,

  ```none
  S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
  ```

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
  Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
  ```

  where:

  * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
  * `Z` is the normalization constant aka the [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The `concentration` represents mean total counts of class occurrence, i.e.,

  ```none
  concentration = alpha = mean * total_concentration
  ```

  where `mean` in `S^{k-1}` and `total_concentration` is a positive real number
  representing a mean total count.
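
  For example, `concentration = [1., 2., 3.]` factors as
  `mean = [1/6, 1/3, 1/2]` scaled by `total_concentration = 6.`.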

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: Some components of the samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round such zero components up to `np.finfo(dtype).tiny` before
  computing the density.
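
  A minimal sketch of that rounding, assuming `dist` is a `Dirichlet` instance
  as in the examples below:

  ```python
  samples = dist.sample(10)
  # Round zero components up to the smallest positive float of the dtype
  # before scoring the samples.
  tiny = np.finfo(samples.dtype.as_numpy_dtype).tiny
  log_prob = dist.log_prob(tf.maximum(samples, tiny))
  ```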

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in the paper

  [Michael Figurnov, Shakir Mohamed, Andriy Mnih.
  Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)

  #### Examples

  ```python
  # Create a single trivariate Dirichlet, with the 3rd class being three times
  # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
  alpha = [1., 2, 3]
  dist = tf.distributions.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 3]

  # x has one sample, one batch, three classes:
  x = [.2, .3, .5]  # shape: [3]
  dist.prob(x)  # shape: []

  # x has two samples from one batch:
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  dist.prob(x)  # shape: [2]

  # alpha will be broadcast to shape [5, 7, 3] to match x.
  x = [[...]]  # shape: [5, 7, 3]
  dist.prob(x)  # shape: [5, 7]
  ```

  ```python
  # Create batch_shape=[2], event_shape=[3]:
  alpha = [[1., 2, 3],
           [4, 5, 6]]  # shape: [2, 3]
  dist = tf.distributions.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching the [2, 3] shape of alpha.
  dist.prob(x)  # shape: [2]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant([1.0, 2.0, 3.0])
  dist = tf.distributions.Dirichlet(alpha)
  samples = dist.sample(5)  # Shape [5, 3]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, alpha)
  ```
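
  The KL divergence between two Dirichlet distributions is also registered in
  this module; a minimal sketch, assuming the `tf.distributions.kl_divergence`
  endpoint:

  ```python
  d1 = tf.distributions.Dirichlet([1., 2., 3.])
  d2 = tf.distributions.Dirichlet([2., 2., 2.])
  kl = tf.distributions.kl_divergence(d1, d2)  # shape: []
  ```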

  """

  def __init__(self,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="Dirichlet"):
    """Initialize a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`, and
        `self.batch_shape`, `self.event_shape`, i.e., if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and
        `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration]) as name:
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration, name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(Dirichlet, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._total_concentration],
        name=name)

  @property
  def concentration(self):
    """Concentration parameter; expected counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
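    # A Dirichlet sample is constructed by drawing independent
    # Gamma(concentration[..., j], 1) variates and normalizing them to
    # sum to 1 along the event dimension.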
    gamma_sample = random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        dtype=self.dtype,
        seed=seed)
    return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keepdims=True)

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return math_ops.reduce_sum((self.concentration - 1.) * math_ops.log(x), -1)

  def _log_normalization(self):
    return special_math_ops.lbeta(self.concentration)

  def _entropy(self):
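    # H = log B(alpha) + (alpha_0 - k) * digamma(alpha_0)
    #     - sum_j (alpha_j - 1) * digamma(alpha_j),
    # where alpha_0 = sum_j alpha_j is the total concentration and k is the
    # event size.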
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    return (
        self._log_normalization()
        + ((self.total_concentration - k)
           * math_ops.digamma(self.total_concentration))
        - math_ops.reduce_sum(
            (self.concentration - 1.) * math_ops.digamma(self.concentration),
            axis=-1))

  def _mean(self):
    return self.concentration / self.total_concentration[..., array_ops.newaxis]

  def _covariance(self):
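    # Off-diagonal: Cov[X_i, X_j] = -mean_i * mean_j / (1 + alpha_0); the
    # diagonal of the outer product is then overwritten with the variance.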
    x = self._variance_scale_term() * self._mean()
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
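    # Var[X_j] = mean_j * (1 - mean_j) / (1 + alpha_0): with
    # scale = rsqrt(1 + alpha_0) and x = scale * mean,
    # x * (scale - x) = scale**2 * mean * (1 - mean).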
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when any `concentration <= 1`. If
      `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
      `self.allow_nan_stats` is `False` an exception is raised when one or more
      modes are undefined.""")
  def _mode(self):
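    # mode[j] = (alpha_j - 1) / (alpha_0 - k), defined when every alpha_j > 1.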
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    mode = (self.concentration - 1.) / (
        self.total_concentration[..., array_ops.newaxis] - k)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          array_ops.shape(mode),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      return array_ops.where(
          math_ops.reduce_all(self.concentration > 1., axis=-1),
          mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], self.dtype),
            self.concentration,
            message="Mode undefined when any concentration <= 1"),
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
        check_ops.assert_rank_at_least(
            concentration, 1,
            message="Concentration parameter must have >=1 dimensions."),
        check_ops.assert_less(
            1, array_ops.shape(concentration)[-1],
            message="Concentration parameter must have event_size >= 2."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="samples must be positive"),
        check_ops.assert_near(
            array_ops.ones([], dtype=self.dtype),
            math_ops.reduce_sum(x, -1),
            message="sample last-dimension must sum to `1`"),
    ], x)


@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_dirichlet_dirichlet".

  Returns:
    Batchwise KL(d1 || d2)
  """
  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
      d1.concentration, d2.concentration]):
    # The KL between Dirichlet distributions can be derived as follows. We have
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #
    # where B(a) is the multivariate Beta function:
    #
    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # The KL is
    #
    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #
    # so we'll need to know the log density of the Dirichlet. This is
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # The only term that matters for the expectations is the log(x[i]). To
    # compute the expectation of this term over the Dirichlet density, we can
    # use the following facts about the Dirichlet in exponential family form:
    #   1. log(x[i]) is a sufficient statistic
    #   2. expected sufficient statistics (of any exp family distribution) are
    #      equal to derivatives of the log normalizer with respect to
    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
    #
    # To proceed, we can rewrite the Dirichlet density in exponential family
    # form as follows:
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #
    # where '.' is the dot product of vectors eta and T, and A is a scalar:
    #
    #   eta[i](a) = a[i] - 1
    #   T[i](x) = log(x[i])
    #   A(a) = log B(a)
    #
    # Now, we can use fact (2) above to write
    #
    #   E_Dir(x; a)[log(x[i])]
    #     = dA(a) / da[i]
    #     = d/da[i] log B(a)
    #     = d/da[i] (sum_j lgamma(a[j]) - lgamma(sum_j a[j]))
    #     = digamma(a[i]) - digamma(sum_j a[j])
    #
    # Putting it all together, we have
    #
    #   KL[Dir(x; a) || Dir(x; b)]
    #     = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #     = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
    #     = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #       - lbeta(a) + lbeta(b)

    digamma_sum_d1 = math_ops.digamma(
        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
    concentration_diff = d1.concentration - d2.concentration

    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
            special_math_ops.lbeta(d1.concentration) +
            special_math_ops.lbeta(d2.concentration))