# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import hashlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect


def assert_integer_form(
    x, data=None, summarize=None, message=None,
    int_dtype=None, name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
      implies the smallest possible signed int will be used for casting.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x, math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data, summarize=summarize, message=message, name=name)


def assert_symmetric(matrix):
  matrix_t = array_ops.matrix_transpose(matrix)
  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Assert x is a non-negative tensor, and optionally of integers."""
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    assertions = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      assertions += [
          assert_integer_form(
              x,
              message="'{}' cannot contain fractional components.".format(x)),
      ]
    return control_flow_ops.with_dependencies(assertions, x)
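

# Illustrative usage sketch (added for clarity; not part of the original
# module). It assumes a TF1-style graph/session environment and hypothetical
# values, and shows how the non-negative integer-form check above behaves:
#
#   counts = ops.convert_to_tensor([1., 2., 3.])
#   checked = embed_check_nonnegative_integer_form(counts)
#   # Evaluating `checked` yields [1., 2., 3.]; with input [1., 2.5, 3.] the
#   # embedded assertion raises InvalidArgumentError at run time.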


def same_dynamic_shape(a, b):
  """Returns whether a and b have the same dynamic shape.

  Args:
    a: `Tensor`
    b: `Tensor`

  Returns:
    `bool` `Tensor` representing if both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  # Here we can't just do math_ops.equal(a.shape, b.shape), since
  # static shape inference may break the equality comparison between
  # shape(a) and shape(b) in math_ops.equal.
  def all_shapes_equal():
    return math_ops.reduce_all(math_ops.equal(
        array_ops.concat([array_ops.shape(a), array_ops.shape(b)], 0),
        array_ops.concat([array_ops.shape(b), array_ops.shape(a)], 0)))

  # One of the shapes isn't fully defined, so we need to use the dynamic
  # shape.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
      all_shapes_equal,
      lambda: constant_op.constant(False))


def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)
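

# Illustrative sketch (hypothetical values; TF1-style graph assumed).
# `maybe_get_static_value` returns a numpy value when one can be inferred and
# `None` otherwise (e.g., for a placeholder):
#
#   maybe_get_static_value(constant_op.constant(7))              # ==> 7
#   maybe_get_static_value(array_ops.placeholder(dtypes.int32))  # ==> None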


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs"):
  """Converts logits to probabilities (or vice-versa), and returns both.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, `logits` or
      `probs` is a `[N1, N2, ..., k]`-dimensional tensor whose last dimension
      represents the logit or probability of each of `shape[-1]` classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one.
    name: A name for this operation (optional).

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf`
      and `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits")
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can early return since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs")
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [check_ops.assert_less_equal(
              probs, one, message="probs has components greater than 1.")]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # Here we don't compute the multidimensional case, in a manner
        # consistent with respect to the unidimensional case. We do so
        # following the TF convention. Typically, you might expect to see
        # logits = log(probs) - log(probs[pivot]). A side-effect of being
        # consistent with the TF approach is that the unidimensional case
        # implicitly handles the second dimension but the multidimensional
        # case explicitly keeps the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
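

# Illustrative sketch (hypothetical values). With `probs=[0.1, 0.9]` the
# returned logits are `log(p) - log1p(-p)`, i.e. roughly [-2.197, 2.197];
# feeding those logits back recovers the probabilities via `sigmoid`:
#
#   logits, probs = get_logits_and_probs(probs=[0.1, 0.9])
#   logits2, probs2 = get_logits_and_probs(logits=logits)
#   # probs2 == sigmoid(logits) == [0.1, 0.9] (up to float error).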


def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We actually can't land here but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool


def embed_check_categorical_event_shape(
    categorical_param,
    name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding. E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape
  is `int32` and categorical variables are presumed to be indexes into a
  `Tensor`, we must also ensure that the number of classes is no larger than
  the largest possible `int32` index, i.e., `2**31-1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest int32 index.
      {
          dtypes.float16: int(2**11),  # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed both of:
    # - The largest possible int32 (since categorical values are presumed to
    #   be indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (_largest_integer_by_dtype(x_dtype)
                      if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if x_shape_static[-1].value is not None:
      event_size = x_shape_static[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError(
            "Number of classes exceeds `dtype` precision, i.e., "
            "{} implies shape ({}) cannot exceed {}.".format(
                x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x, 1,
              message=("A categorical-distribution parameter must have "
                       "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1],
              2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size,
              max_event_size,
              message="Number of classes exceeds `dtype` precision, "
                      "i.e., {} dtype cannot exceed {} shape.".format(
                          x_dtype.name, max_event_size)),
      ], x)
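

# Illustrative sketch (hypothetical shapes). A float16 parameter can index at
# most 2**11 classes, so a statically known event size larger than that fails
# immediately, while the same shape in float32 passes:
#
#   p16 = array_ops.zeros([3, 4096], dtype=dtypes.float16)
#   embed_check_categorical_event_shape(p16)  # ==> raises ValueError
#   p32 = array_ops.zeros([3, 4096], dtype=dtypes.float32)
#   embed_check_categorical_event_shape(p32)  # ==> returns p32 unchanged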


def embed_check_integer_casting_closed(
    x,
    target_dtype,
    assert_nonnegative=True,
    name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to
  have integer-form values can be cast to some other type without loss of
  precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For
  these types, the smallest representable integer is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative
      values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype)
        and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype)
        and not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype)
        and not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(
                          x, x.dtype.name, target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, we needn't repeat
      # the latter.
      assertions += [
          assert_integer_form(
              x, int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype)
          > _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x, _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and
          (_smallest_integer_by_dtype(x.dtype)
           < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x, _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)


def log_combinations(n, counts, name="log_combinations"):
  """Multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the multinomial coefficient as:

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents
      `n` outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
      and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations
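

# Illustrative sketch (hypothetical values). For n = 3 and counts = [1, 2],
# the multinomial coefficient is 3! / (1! * 2!) = 3, so the op returns
# log(3) ~= 1.0986:
#
#   log_combinations(3., [1., 2.])  # ==> ~1.0986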


def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops. Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)
    return transformed_mat


def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When `shift` is not known statically, additional graph-runtime checks are
  performed. These checks entail moving data from GPU to CPU.

  Example:

  ```python
  x = tf.random_normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2:
        return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0:
        return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where(math_ops.less(shift, 0),
                              math_ops.mod(-shift, ndims),
                              ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)


def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created
  and no validation happens.

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s" %
          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where(cond, 0, n)], [array_ops.where(cond, n, -1)])


def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor` or `TensorShape`.
    shape2: `1-D` integer `Tensor` or `TensorShape`.
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
      statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):
    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)


def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """
  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.shape.with_rank_at_least(1)[-1].value is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on
    # the wrong side of the diagonal. The first row will not get zeroed out
    # at all, and we need `floor(n/2)` more rows, so the first is what we
    # omit from `x_tail`. If we then stack those `ceil(n/2)` rows with the
    # `floor(n/2)` rows provided by a reversed tail, it is exactly the other
    # set of elements of the reversed tail which will be zeroed out for being
    # on the wrong side of the diagonal further up/down the matrix. And, in
    # doing so, we've filled the triangular matrix in a clock-wise spiral
    # pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this
    # doesn't correctly handle `m == n == 1`. Hence, we do nonnegative
    # indexing. Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - (n**2 / 2 + n / 2))
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list()
        if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x


def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])
  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)
  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized
      lower (or upper) triangular elements from `x`.
  """
  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.shape.with_rank_at_least(2)[-1].value is not None:
      n = np.int32(x.shape[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]],
                         axis=-1)
    y.set_shape(static_final_shape)
    return y


def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part. `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part. `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, "
                       "`above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three
    # full matrices.
    return _add(below, diag, above)
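

# Illustrative sketch (hypothetical values) for the two triangular helpers
# above: a [2, 6] input fills a [2, 3, 3] lower-triangular batch, and
# `fill_triangular_inverse` recovers the original vector:
#
#   x = constant_op.constant([[1., 2., 3., 4., 5., 6.],
#                             [7., 8., 9., 10., 11., 12.]])
#   tril = fill_triangular(x)      # shape [2, 3, 3]
#   fill_triangular_inverse(tril)  # ==> same values as `x`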


def reduce_weighted_logsumexp(
    logx,
    w=None,
    axis=None,
    keep_dims=False,
    return_sign=False,
    name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to
  directly use `reduce_logsumexp`, i.e.,
  `tf.reduce_logsumexp(logx + tf.log(w))` is more efficient than
  `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`. Unless
  `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry
  in `axis`. If `keep_dims` is true, the reduced dimensions are retained with
  length 1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(sum(w * exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])
  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother
    # subtracting off the max. We do this because otherwise we'd get
    # `inf - inf = NaN`. That this is ok follows from the fact that we're
    # actually free to subtract any value we like, so long as we add it back
    # after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where(
        math_ops.is_inf(max_log_absw_x),
        array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe
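

# Illustrative sketch (hypothetical values) of the `return_sign` option,
# which the docstring examples above do not cover. With all-zero `logx` and
# weights [-1, -2], `sum(w * exp(logx)) = -3`:
#
#   lswe, sign = reduce_weighted_logsumexp(
#       constant_op.constant([0., 0.]),
#       constant_op.constant([-1., -2.]),
#       return_sign=True)
#   # ==> lswe ~= log(3) ~= 1.0986, sign == -1.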


# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large
    # x. For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x}
    # will be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are
    # careful to overwrite `x` with ones only when we will never actually use
    # this value. Note that we use ones and not zeros since
    # `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where(math_ops.logical_or(is_too_small, is_too_large),
                        array_ops.ones_like(x), x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where(is_too_small, too_small_value,
                           array_ops.where(is_too_large, too_large_value, y))
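

# Illustrative sketch (hypothetical values). `softplus_inverse` undoes
# `nn.softplus`, i.e. softplus_inverse(softplus(x)) ~= x:
#
#   y = nn.softplus(constant_op.constant([-5., 0., 5.]))
#   softplus_inverse(y)  # ==> ~[-5., 0., 5.]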


# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check
  # the static shape or fallback to dynamic shape.
  s = x.shape.with_rank_at_least(np.abs(axis))[axis].value
  if s is not None:
    return s
  return array_ops.shape(x)[axis]


def process_quadrature_grid_and_probs(
    quadrature_grid_and_probs, dtype, validate_args, name=None):
  """Validates quadrature grid, probs or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight. When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
                                  dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True,
                             name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return x.shape.with_rank_at_least(1)[-1].value

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs


def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to
      pad. (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is
      made.
    back: Python `bool`; if `True` the end of the `axis` dimension is padded
      with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to
      the front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added
      to the front and/or back of the `axis` dimension of `x`. E.g., if
      `front = back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (x.shape.ndims if x.shape.ndims is not None
             else array_ops.rank(x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(
            None if count_ is None
            else (x.shape[axis] + count_ * (front + back)))
        tail = x.shape[axis+1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack([axis if front else -1,
                                     axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x


def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's
  function arguments. These are positional arguments and keyword arguments
  (**kwargs), while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since
  there are no arguments.

  WARNING: If caller function argument names are overloaded before invoking
  this method, then values will reflect the overloaded value. For this
  reason, we recommend calling `parent_frame_arguments` at the beginning of
  the function.
  """
  # All arguments and the names used for *varargs, and **kwargs
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs. Both are
  # nested lists.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args


class AppendDocstring(object):
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
        additional_note="A special note!",
        kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note`
  to the docstring of `prob` (not `_prob`) and adds a new `kwargs` section
  with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing specific
        kwargs expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError(
              "Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("* `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" +
                                "\n".join(bullets))

  def __call__(self, fn):
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)
    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn