laywerrobot/lib/python3.6/site-packages/tensorflow/python/ops/distributions/util.py
2020-08-27 21:55:39 +02:00

1433 lines
53 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import hashlib
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect
def assert_integer_form(
    x, data=None, summarize=None, message=None,
    int_dtype=None, name="assert_integer_form"):
  """Asserts that every element of `x` is integer-valued.

  Integer-typed inputs pass trivially (a no-op is returned). Floating-point
  inputs are cast to an integer dtype and back; any element that changes under
  that round trip is reported as non-integer.

  Args:
    x: Floating-point `Tensor` (integer `Tensor`s are accepted and trivially
      pass).
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and of its integer
      round-trip.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
      selects the signed integer of the same bit-width as `x`
      (`float16` -> `int16`, `float32` -> `int32`, `float64` -> `int64`).
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.

  Raises:
    TypeError: if `int_dtype` is `None` and `x` is not `float16`, `float32`
      or `float64`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    if not message:
      message = "{} has non-integer components".format(x)
    if int_dtype is None:
      same_width_int = {
          dtypes.float16: dtypes.int16,
          dtypes.float32: dtypes.int32,
          dtypes.float64: dtypes.int64,
      }
      base = x.dtype.base_dtype
      if base not in same_width_int:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
      int_dtype = same_width_int[base]
    # float -> int -> float is the identity exactly on integer-valued entries
    # that fit in `int_dtype`.
    round_tripped = math_ops.cast(math_ops.cast(x, int_dtype), x.dtype)
    return check_ops.assert_equal(
        x, round_tripped,
        data=data, summarize=summarize, message=message, name=name)
def assert_symmetric(matrix):
  """Returns `matrix` gated on a runtime check that it equals its transpose.

  Args:
    matrix: `Tensor` whose two innermost dimensions form a (batch of) matrix;
      presumably square, since otherwise it cannot equal its transpose.

  Returns:
    `matrix`, with an `assert_equal(matrix, matrix_transpose(matrix))`
    attached as a control dependency.
  """
  symmetric_check = check_ops.assert_equal(
      matrix, array_ops.matrix_transpose(matrix))
  return control_flow_ops.with_dependencies([symmetric_check], matrix)
def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Gates `x` on checks that it is non-negative and integer-valued.

  Integer-typed `x` only receives the non-negativity check. Any other dtype
  is additionally checked for fractional components via
  `assert_integer_form`.

  Args:
    x: `Tensor` (or convertible) to validate.
    name: A name for this operation (optional).

  Returns:
    `x` with the assertions attached as control dependencies.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    checks = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      fractional_msg = "'{}' cannot contain fractional components.".format(x)
      checks.append(assert_integer_form(x, message=fractional_msg))
    return control_flow_ops.with_dependencies(checks, x)
def same_dynamic_shape(a, b):
  """Returns whether `a` and `b` have the same shape at graph run time.

  Args:
    a: `Tensor`
    b: `Tensor`

  Returns:
    Scalar `bool` `Tensor` that is `True` iff both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  def _compare_dims():
    # A direct math_ops.equal(shape(a), shape(b)) is avoided because static
    # shape inference may break that equality comparison. Concatenating the
    # two shapes in both orders gives two vectors of equal length; given equal
    # ranks, they match element-wise iff shape(a) == shape(b).
    a_then_b = array_ops.concat([array_ops.shape(a), array_ops.shape(b)], 0)
    b_then_a = array_ops.concat([array_ops.shape(b), array_ops.shape(a)], 0)
    return math_ops.reduce_all(math_ops.equal(a_then_b, b_then_a))

  # The shapes may not be fully defined, so the comparison is done on dynamic
  # shapes, and only once the ranks are known to agree.
  ranks_match = math_ops.equal(array_ops.rank(a), array_ops.rank(b))
  return control_flow_ops.cond(
      ranks_match, _compare_dims, lambda: constant_op.constant(False))
def maybe_get_static_value(x, dtype=None):
  """Tries to return the value of `x` as a constant known at graph-build time.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return None
  try:
    # `constant_value` yields an `np.ndarray`, or `None` when `x` cannot be
    # evaluated statically.
    static_x = tensor_util.constant_value(x)
  except TypeError:
    # Presumably `x` is not a `Tensor` (e.g. a Python value); use it as-is.
    static_x = x
  if static_x is not None and dtype is not None:
    return np.array(static_x, dtype)
  return static_x
def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs"):
  """Converts logits to probabilities (or vice-versa), and returns both.

  Exactly one of `logits` or `probs` must be given; the other is derived.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, the last
      dimension of the `[N1, N2, ..., k]`-shaped `logits` or `probs` indexes
      `k` classes, and softmax is used to derive `probs`. If `False`, every
      entry is treated independently and sigmoid / log-odds are used.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one (plus the checks of
      `embed_check_categorical_event_shape` when `multidimensional`).
    name: A name for this operation (optional).

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf`
      and `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
    TypeError: if the supplied `logits` or `probs` is not floating-point.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      # Deriving probs from logits: probs are valid by construction, so only
      # the (optional) event-shape check applies.
      logits = ops.convert_to_tensor(logits, name="logits")
      if not logits.dtype.is_floating:
        raise TypeError("logits must having floating type.")
      if not multidimensional:
        return logits, math_ops.sigmoid(logits, name="probs")
      if validate_args:
        logits = embed_check_categorical_event_shape(logits)
      return logits, nn.softmax(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs")
    if not probs.dtype.is_floating:
      raise TypeError("probs must having floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        checks = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          # Each last-axis slice must be a distribution.
          checks.append(check_ops.assert_near(
              math_ops.reduce_sum(probs, -1),
              one,
              message="probs does not sum to 1."))
        else:
          checks.append(check_ops.assert_less_equal(
              probs, one, message="probs has components greater than 1."))
        probs = control_flow_ops.with_dependencies(checks, probs)

    with ops.name_scope("logits"):
      log_probs = math_ops.log(probs)
      if multidimensional:
        # Following the TF convention, the multidimensional case returns
        # log(probs) and keeps every class, rather than the pivoted
        # log(probs) - log(probs[pivot]). The unidimensional case below
        # implicitly handles the complementary class instead.
        return log_probs, probs
      # Log-odds: log(p) - log(1 - p).
      return log_probs - math_ops.log1p(-1. * probs), probs
def _is_known_unsigned_by_dtype(dt):
  """Returns True iff the base dtype of `dt` is `bool`, `uint8` or `uint16`."""
  known_unsigned = (dtypes.bool, dtypes.uint8, dtypes.uint16)
  return dt.base_dtype in known_unsigned
def _is_known_signed_by_dtype(dt):
  """Returns True iff the base dtype of `dt` is a float16/32/64 or int8-64."""
  known_signed = (
      dtypes.float16, dtypes.float32, dtypes.float64,
      dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
  )
  return dt.base_dtype in known_signed
def _is_known_dtype(dt):
  """Returns True iff `dt` is recognized by either signedness helper."""
  if _is_known_unsigned_by_dtype(dt):
    return True
  return _is_known_signed_by_dtype(dt)
def _largest_integer_by_dtype(dt):
  """Returns the largest integer exactly representable by `dt`.

  Raises:
    TypeError: if `dt` is not one of the known dtypes.
  """
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    # With `nmant` stored mantissa bits (plus the implicit leading bit), every
    # integer up to 2**(nmant + 1) is exactly representable.
    mantissa_bits = np.finfo(dt.as_numpy_dtype).nmant
    return int(2**(mantissa_bits + 1))
  elif dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  elif dt.base_dtype == dtypes.bool:
    return 1
  # Not reachable for known dtypes; retained for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))
def _smallest_integer_by_dtype(dt):
  """Returns the smallest integer exactly representable by `dt`.

  Known-unsigned dtypes give 0; signed dtypes give the negated largest value.

  Raises:
    TypeError: if `dt` is not one of the known dtypes.
  """
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return (0 if _is_known_unsigned_by_dtype(dt)
          else -_largest_integer_by_dtype(dt))
def _is_integer_like_by_dtype(dt):
  """Returns True iff `dt` is an integer dtype or `bool`.

  Raises:
    TypeError: if `dt` is not one of the known dtypes.
  """
  if _is_known_dtype(dt):
    return dt.is_integer or dt.base_dtype == dtypes.bool
  raise TypeError("Unrecognized dtype: {}".format(dt.name))
def embed_check_categorical_event_shape(
    categorical_param,
    name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding. E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape is
  `int32` and categorical variables are presumed to be indexes into a `Tensor`,
  we must also ensure that the number of classes is no larger than the largest
  possible `int32` index, i.e., `2**31-1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest int32, i.e., largest possible index.
      {
          dtypes.float16: int(2**11),   # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` is not floating-point or has an unknown
      `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting), as having rank 0, or
      as having fewer than 2 events.
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed both of:
    # - The largest possible int32 (since categorical values are presumed to be
    #   indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (_largest_integer_by_dtype(x_dtype)
                      if x_dtype.is_floating else 0)
    # Compare by value: `is` on int literals relies on interning.
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    # Apply the documented int32 bound. Without it, float64 yields 2**53, which
    # does not fit in int32 and cannot be compared against the int32 dynamic
    # shape used below.
    max_event_size = min(max_event_size, int(np.iinfo(np.int32).max))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if x_shape_static[-1].value is not None:
      # Event size known at graph-construction time: validate in Python.
      event_size = x_shape_static[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError(
            "Number of classes exceeds `dtype` precision, i.e., "
            "{} implies shape ({}) cannot exceed {}.".format(
                x_dtype.name, event_size, max_event_size))
      return x
    else:
      # Event size only known at run time: embed graph assertions instead.
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x, 1, message=("A categorical-distribution parameter must have "
                             "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              event_size, 2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size, max_event_size,
              message="Number of classes exceeds `dtype` precision, "
                      "i.e., {} dtype cannot exceed {} shape.".format(
                          x_dtype.name, max_event_size)),
      ], x)
def embed_check_integer_casting_closed(
    x,
    target_dtype,
    assert_nonnegative=True,
    name="embed_check_casting_closed"):
  """Ensures integer-form values of `x` survive a cast to `target_dtype`.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  An IEEE754 floating-point type represents every integer up to
  `2**(1 + mantissa_bits)` exactly, whereas an int-type tops out at
  `2**(bits - 1) - 1`. This function attaches runtime checks so that a
  `Tensor` purporting to hold integer-form values can be cast to
  `target_dtype` without loss of precision.

  The smallest representable integer is taken to be the negative of the largest
  one, except for `uint8`, `uint16` and `bool`, whose smallest is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded (or `x` unchanged
      if no check is needed).

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    x_int_like = _is_integer_like_by_dtype(x.dtype)
    if not (x_int_like or x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    target_int_like = _is_integer_like_by_dtype(target_dtype)
    if not (target_int_like or target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if not (x_int_like or target_int_like):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(
                          x, x.dtype.name, target_dtype.name))

    checks = []
    if assert_nonnegative:
      checks.append(check_ops.assert_non_negative(
          x, message="Elements must be non-negative."))

    if x.dtype.is_floating:
      # `target_dtype` is necessarily integer-like here. The round-trip check
      # is relied on in place of the magnitude checks in the other branch.
      checks.append(assert_integer_form(
          x, int_dtype=target_dtype,
          message="Elements must be {}-equivalent.".format(
              target_dtype.name)))
    else:
      target_max = _largest_integer_by_dtype(target_dtype)
      if _largest_integer_by_dtype(x.dtype) > target_max:
        # Cast may lose integer precision at the top of the range.
        checks.append(check_ops.assert_less_equal(
            x, target_max,
            message=("Elements cannot exceed {}.".format(target_max))))
      if not assert_nonnegative:
        target_min = _smallest_integer_by_dtype(target_dtype)
        if _smallest_integer_by_dtype(x.dtype) < target_min:
          checks.append(check_ops.assert_greater_equal(
              x, target_min,
              message=("Elements cannot be smaller than {}.".format(
                  target_min))))

    if checks:
      return control_flow_ops.with_dependencies(checks, x)
    return x
def log_combinations(n, counts, name="log_combinations"):
  """Log of the multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the logarithm of the multinomial coefficient

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes, i.e.
  `lgamma(n + 1) - sum_i lgamma(n_i + 1)`.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
    and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!), computed in log-space
  # as lgamma(sum counts + 1) - sum(lgamma(counts + 1)).
  # The reduction is along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    # lgamma(k + 1) == log(k!).
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations
def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To
      be applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops.
      Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is not None:
      # Replace only the diagonal; off-diagonal entries are carried over.
      new_diag = transform(array_ops.matrix_diag_part(matrix))
      matrix = array_ops.matrix_set_diag(matrix, new_diag)
    return matrix
def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  If both the rank of `x` and the value of `shift` are known statically, the
  permutation is computed with NumPy while the graph is built, and `x` is
  returned as-is when it has rank < 2 or the effective shift is zero.
  Otherwise the permutation is computed inside the graph from `tf.rank(x)`
  and `shift`, so those values are needed at run time.

  Example:

  ```python
  x = tf.random_normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  7).shape == rotate_transpose(x,  3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type (as reported by
      `check_ops.assert_integer`).
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2: return x
      # Reduce |shift| modulo ndims while keeping its sign (direction).
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0: return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      # NOTE(review): unlike the static path, rank < 2 is not special-cased
      # here; a rank-0 `x` would take `mod(shift, 0)`.
      ndims = array_ops.rank(x)
      shift = array_ops.where(math_ops.less(shift, 0),
                              math_ops.mod(-shift, ndims),
                              ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)
def pick_vector(cond,
                true_vector,
                false_vector,
                name="pick_vector"):
  """Selects one of two (possibly different-length) 1-D `Tensor`s.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created and
  no validation happens.

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))

    static_cond = tensor_util.constant_value(cond)
    if static_cond is not None:
      # Decided while building the graph: hand back the chosen input as-is.
      return true_vector if static_cond else false_vector

    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s"
          % (true_vector, true_vector.dtype,
             false_vector, false_vector.dtype))

    # Concatenate both vectors, then slice out either the leading `n` entries
    # (the true part) or everything after them (a size of -1 means "to the end").
    n = array_ops.shape(true_vector)[0]
    both = array_ops.concat([true_vector, false_vector], 0)
    begin = array_ops.where(cond, 0, n)
    size = array_ops.where(cond, n, -1)
    return array_ops.slice(both, [begin], [size])
def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Broadcasts two shapes, statically when both are known, else dynamically.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.

  Raises:
    ValueError: if dynamic broadcasting is needed and one of the inputs is a
      partially defined `TensorShape`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):
    def _as_shape_tensor(s):
      return ops.convert_to_tensor(s, name="shape", dtype=dtypes.int32)

    def _static_shape_or_none(s):
      # A TensorShape passes through; otherwise try to constant-fold.
      if isinstance(s, tensor_shape.TensorShape):
        return s
      value = tensor_util.constant_value(_as_shape_tensor(s))
      return None if value is None else tensor_shape.TensorShape(value)

    def _dynamic_shape(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return _as_shape_tensor(s)
      if not s.is_fully_defined():
        raise ValueError("Cannot broadcast from partially "
                         "defined `TensorShape`.")
      return _as_shape_tensor(s.as_list())

    static1 = _static_shape_or_none(shape1)
    static2 = _static_shape_or_none(shape2)
    if static1 is not None and static2 is not None:
      return array_ops.broadcast_static_shape(static1, static2)
    return array_ops.broadcast_dynamic_shape(
        _dynamic_shape(shape1), _dynamic_shape(shape2))
def prefer_static_rank(x):
  """Returns the rank of `x` as a constant when possible, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  rank = array_ops.rank(x)
  return prefer_static_value(rank)
def prefer_static_shape(x):
  """Returns the shape of `x` as a constant when possible, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  shape = array_ops.shape(x)
  return prefer_static_value(shape)
def prefer_static_value(x):
  """Returns the constant-folded value of `x` when available, else `x` itself.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  return x if static_x is None else static_x
def gen_new_seed(seed, salt):
  """Derives a new seed deterministically from `seed` and the string `salt`.

  Returns `None` when `seed` is `None`. Otherwise the MD5 digest of
  `str(seed) + salt` is truncated to its first 32 bits, and the top bit is
  cleared, so the result is a non-negative int below 2**31.
  """
  if seed is not None:
    digest = hashlib.md5((str(seed) + salt).encode("utf-8")).hexdigest()
    return int(digest[:8], 16) & 0x7FFFFFFF
  return None
def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix. This is only
      checked here when the right-most dimension of `x` is statically known.
  """
  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.shape.with_rank_at_least(1)[-1].value is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. When m = n(n+1)/2, sqrt(0.25 + 2m) equals
      # n + 0.5, and the truncating cast to int32 drops the 0.5, so the
      # "- 0.5" is omitted. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on the
    # wrong side of the diagonal. The first row will not get zeroed out at all,
    # and we need `floor(n/2)` more rows, so the first is what we omit from
    # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
    # rows provided by a reversed tail, it is exactly the other set of elements
    # of the reversed tail which will be zeroed out for being on the wrong side
    # of the diagonal further up/down the matrix. And, in doing-so, we've filled
    # the triangular matrix in a clock-wise spiral pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #               [5, 4, 3],
    #               [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #               [3, 4, 5],
    #               [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 - n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    # Both branches concatenate (m - n) + m = n**2 elements along the last
    # axis, which is exactly enough to reshape into an n x n matrix.
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list()
        if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    # Zero the wrong side of the diagonal: -1 keeps the whole band on that
    # side, 0 keeps none of it.
    x = array_ops.matrix_band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x
def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  This undoes `fill_triangular`. The vector is created from the
  lower-triangular or upper-triangular portion depending on the value of the
  parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """
  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.shape.with_rank_at_least(2)[-1].value is not None:
      n = np.int32(x.shape[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    # The row that holds n nonzeros (first row if upper, last row reversed if
    # lower) supplies the first n output entries; the other n-1 rows supply
    # the rest.
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    # Rotating the n-1 remaining rows by 180 degrees (reversing both matrix
    # axes) moves each row's nonzeros onto the positions that are zero in the
    # unrotated row, so summing merges the two halves without overlap. This
    # relies on `x` being truly triangular: entries on the other side of the
    # diagonal would be added in.
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    # The sum is invariant under that same rotation, so its flattening is a
    # palindrome of length n(n-1) = 2(m-n); only the first m-n entries are
    # needed.
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y
def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part. `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part. `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _zero_pad(v):
    """Adds a zero before and after every vector in a batch of vectors."""
    edge_shape = array_ops.concat([array_ops.shape(v)[:-1], [1]], axis=0)
    edge = array_ops.zeros(edge_shape, dtype=v.dtype)
    return array_ops.concat([edge, v, edge], axis=-1)

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    # Each supplied band becomes a full `d x d` matrix; the bands are summed in
    # the order below, diag, above.
    bands = []
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      # The padded vector's `(d+1) x (d+1)` diagonal matrix, minus its last
      # row and first column, carries `below` on the subdiagonal.
      bands.append(array_ops.matrix_diag(_zero_pad(below))[..., :-1, 1:])
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      bands.append(array_ops.matrix_diag(diag))
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      # Dropping the first row and last column instead puts `above` on the
      # superdiagonal.
      bands.append(array_ops.matrix_diag(_zero_pad(above))[..., 1:, :-1])
    if not bands:
      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    result = bands[0]
    for band in bands[1:]:
      result = result + band
    return result
def reduce_weighted_logsumexp(
    logx,
    w=None,
    axis=None,
    keep_dims=False,
    return_sign=False,
    name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.log(w))` is more
  efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `logx` along the dimensions given in `axis`. Unless `keep_dims` is
  true, the rank of the tensor is reduced by 1 for each entry in `axis`. If
  `keep_dims` is true, the reduced dimensions are retained with length 1. If
  `axis` is `None`, all dimensions are reduced, and a tensor with a single
  element is returned.

  This function is more numerically stable than `log(sum(w * exp(logx)))`. It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
      `None` means every weight is one.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(logx), rank(logx))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")

    if w is None:
      # Unweighted: every term is positive, so the sign is identically one.
      result = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      return (result, array_ops.ones_like(result)) if return_sign else result

    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    # Move |w| into the exponent: w * exp(logx) == sign(w) * exp(logx + log|w|).
    log_abs_terms = logx + math_ops.log(math_ops.abs(w))
    shift = math_ops.reduce_max(log_abs_terms, axis=axis, keepdims=True)
    # Any finite shift works as long as it is added back after the log. An
    # infinite maximum would give `inf - inf = NaN`, so use zero instead.
    shift = array_ops.where(
        math_ops.is_inf(shift),
        array_ops.zeros_like(shift),
        shift)
    scaled_terms = math_ops.sign(w) * math_ops.exp(log_abs_terms - shift)
    scaled_sum = math_ops.reduce_sum(
        scaled_terms, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      shift = array_ops.squeeze(shift, axis)
    sign = math_ops.sign(scaled_sum)
    # `sign * scaled_sum` is |scaled_sum|.
    log_abs_sum = shift + math_ops.log(sign * scaled_sum)
    return (log_abs_sum, sign) if return_sign else log_abs_sum
# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # A numerically stable form. For x = softplus(y) = log(1 + exp(y)), x > 0:
    #   y = log(exp(x) - 1)                  naive; overflows for large x
    #     = log((exp(x) - 1) / exp(x)) + x
    #     = log(1 - exp(-x)) + x             stable for large x
    # For tiny x, 1 - exp(-x) rounds to zero (log -> -inf), so there we use
    # 1 - exp(-x) ~= x, i.e. the result ~= log(x). For large x the result is
    # ~= x. The cut-offs are chosen to agree with
    # tensorflow/core/kernels/softplus_op.h.
    log_cutoff = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    small_mask = math_ops.less(x, np.exp(log_cutoff))
    large_mask = math_ops.greater(x, -log_cutoff)
    small_branch = math_ops.log(x)
    # The gradient of `where` flows through both branches, so an `inf` in an
    # unselected branch becomes 0 * inf = NaN. Feed the stable formula a
    # harmless 1 wherever its output is discarded. Use ones, not zeros,
    # because log(expm1(0.)) = -inf.
    safe_x = array_ops.where(math_ops.logical_or(small_mask, large_mask),
                             array_ops.ones_like(x), x)
    stable_branch = safe_x + math_ops.log(-math_ops.expm1(-safe_x))
    return array_ops.where(
        small_mask, small_branch,
        array_ops.where(large_mask, x, stable_branch))
# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension.

  Args:
    x: `Tensor` whose dimension size is requested.
    axis: Python `int` index of the dimension. Negative values count from the
      end.

  Returns:
    The size as a Python `int` when it is statically known, otherwise a scalar
    `Tensor` read from the dynamic shape of `x`.

  Raises:
    ValueError: if the static rank of `x` is known and too small for `axis`.
  """
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fallback to dynamic shape.
  # A non-negative `axis` is only valid when rank >= axis + 1, and a negative
  # one when rank >= -axis. (Using `abs(axis)` under-checks the non-negative
  # case: a scalar with `axis=0` would slip through to an `IndexError`.)
  min_rank = axis + 1 if axis >= 0 else -axis
  s = x.shape.with_rank_at_least(min_rank)[axis].value
  if s is not None:
    return s
  return array_ops.shape(x)[axis]
def process_quadrature_grid_and_probs(
    quadrature_grid_and_probs, dtype, validate_args, name=None):
  """Validates quadrature grid, probs or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight. When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`, with the weights rescaled to
      unit L1 norm.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
     quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (L1-normalized)
      weight.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      # Default rule: 8-point Gauss-Hermite nodes with L1-normalized weights,
      # computed in NumPy and then embedded as constants.
      np_dtype = dtype.as_numpy_dtype
      nodes, weights = np.polynomial.hermite.hermgauss(deg=8)
      nodes = nodes.astype(np_dtype)
      weights = weights.astype(np_dtype)
      weights /= np.linalg.norm(weights, ord=1, keepdims=True)
      return (ops.convert_to_tensor(nodes, name="grid", dtype=dtype),
              ops.convert_to_tensor(weights, name="probs", dtype=dtype))

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
                                  dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")

    # Static sizes of the last dimension (`None` when unknown).
    probs_size = probs.shape.with_rank_at_least(1)[-1].value
    grid_size = grid.shape.with_rank_at_least(1)[-1].value
    if probs_size is not None and grid_size is not None:
      # Both sizes are static: check them now.
      if probs_size != grid_size:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(probs_size, grid_size))
    elif validate_args:
      # At least one size is dynamic: defer the check to graph execution.
      size_check = check_ops.assert_equal(
          dimension_size(probs, axis=-1),
          dimension_size(grid, axis=-1),
          message=("`quadrature_grid_and_probs` must be a `tuple` of "
                   "same-length zero-th-dimension `Tensor`s"))
      with ops.control_dependencies([size_check]):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs
def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is
      padded with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `Tensor` (converted to `x.dtype`) representing the actual
      value added to the front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if
      `front = back = True` then `2 * count` elements are added. Any integer
      dtype is accepted.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (x.shape.ndims if x.shape.ndims is not None
             else array_ops.rank(x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      # Static axis: normalize a negative index. If the rank is only known
      # dynamically, this makes `axis` a `Tensor`.
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        # `axis` is a concrete index, so the padded static shape can be
        # inferred: only dimension `axis` grows, by `count * (front + back)`.
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(
            None if count_ is None
            else (x.shape[axis] + count_ * (front + back)))
        tail = x.shape[axis+1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where(axis < 0, ndims + axis, axis)
      final_shape = None
    # `paddings` is an int32 `[ndims, 2]` matrix. Column 0 is the front and
    # column 1 the back, and row `axis` holds `count`. An index of -1 leaves
    # its column all zeros, i.e. no padding on that side. `one_hot` requires
    # `on_value` to match `dtype=int32`, so cast `count`. Without the cast,
    # integer dtypes accepted above (e.g. int64) are rejected there.
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack([axis if front else -1,
                                     axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=math_ops.cast(count, dtypes.int32),
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x
def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's function
  arguments. These are positional arguments and keyword arguments (**kwargs),
  while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since there
  are no arguments.

  WARNING: If caller function argument names are overloaded before invoking
  this method, then values will reflect the overloaded value. For this reason,
  we recommend calling `parent_frame_arguments` at the beginning of the
  function.

  Returns:
    A `dict` mapping each named argument of the caller (including keyword-only
    arguments) and each entry of its `**kwargs` to its current value.
  """
  import inspect  # pylint: disable=g-import-not-at-top
  # Go straight to the caller's frame. `inspect.stack()` would instead build a
  # record, including source-code context, for every frame on the stack.
  frame = inspect.currentframe().f_back
  try:
    # All arguments and the names used for *varargs, and **kwargs
    arg_names, variable_arg_name, keyword_arg_name, frame_locals = (
        inspect.getargvalues(frame))
  finally:
    # Do not keep a frame reference alive longer than needed (see the
    # `inspect` docs on frame reference cycles).
    del frame
  # Work on a copy so the caller's frame locals are never mutated. On Python
  # 3.13+, `f_locals` is a write-through proxy that refuses key removal.
  # In general the locals may contain more than just the arguments, since
  # this method can be called anywhere in a function.
  local_vars = dict(frame_locals)
  local_vars.pop(variable_arg_name, None)  # *varargs are excluded.
  keyword_args = local_vars.pop(keyword_arg_name, {})  # Flattened below.
  final_args = {arg_name: local_vars[arg_name] for arg_name in arg_names}
  final_args.update(keyword_args)
  return final_args
class AppendDocstring(object):
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
      additional_note="A special note!",
      kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note` to
  the docstring of `prob` (not `_prob`) and adds a new `kwargs`
  section with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing
        specific kwargs expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if not kwargs_dict:
      return
    # Bullets are emitted in sorted key order so the docstring is stable.
    bullets = [self._format_bullet(key, kwargs_dict[key])
               for key in sorted(kwargs_dict)]
    self._additional_note += ("\n\n##### `kwargs`:\n\n" +
                              "\n".join(bullets))

  @staticmethod
  def _format_bullet(key, value):
    """Validates one `kwargs_dict` entry and renders it as a markdown bullet."""
    if any(ch.isspace() for ch in key):
      raise ValueError(
          "Parameter name \"%s\" contains whitespace." % key)
    description = value.lstrip()
    if "\n" in description:
      raise ValueError(
          "Parameter description for \"%s\" contains newlines." % key)
    return "* `%s`: %s" % (key, description)

  def __call__(self, fn):
    """Returns a wrapper around `fn` whose docstring carries the extra note."""
    @functools.wraps(fn)
    def _wrapped(*args, **kwargs):
      return fn(*args, **kwargs)
    # `functools.wraps` copied `fn.__doc__`; `fn` itself is left untouched.
    inherited_doc = _wrapped.__doc__
    _wrapped.__doc__ = (self._additional_note if inherited_doc is None
                        else inherited_doc + "\n%s" % self._additional_note)
    return _wrapped