laywerrobot/lib/python3.6/site-packages/tensorflow/python/ops/custom_gradient.py
2020-08-27 21:55:39 +02:00

224 lines
9.5 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@tf_export("custom_gradient")
def custom_gradient(f):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations. This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at
  x=100 is NaN. For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  See also @{tf.RegisterGradient} which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(x)` that returns a tuple `(y, grad_fn)` where:
      - `x` is a `Tensor` or sequence of `Tensor` inputs to the function.
      - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
        TensorFlow operations in `f` to `x`.
      - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
        a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
        to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
        `Tensor`s the same size as `y` holding the initial value gradients for
        each `Tensor` in `y`. If `f` uses `Variable`s (that are not part of the
        inputs), i.e. through `get_variable`, then `grad_fn` should have
        signature `g(*grad_ys, variables=None)`, where `variables` is a list of
        the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
        `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
        with the derivatives of `Tensor`s in `y` with respect to the variables.

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by @{tf.gradients}) is determined by `f(x)[1]`.
  """

  def decorated(*args, **kwargs):
    """Forward the call to the eager or graph implementation."""
    # The execution mode is checked on every call, not once at decoration
    # time, so the same decorated function works in either mode.
    run_wrapped = (
        _eager_mode_decorator
        if context.executing_eagerly() else _graph_mode_decorator)
    return run_wrapped(f, *args, **kwargs)

  return tf_decorator.make_decorator(f, decorated)
def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f` under a `GradientTape`, then passes its outputs, its inputs and
  the variables it used through one `IdentityN` op whose gradient is
  overridden by a wrapper around the `grad_fn` that `f` returned.

  Args:
    f: function returning a tuple `(result, grad_fn)`.
    *args: positional inputs for `f`; each is converted to a `Tensor`.
    **kwargs: must be empty; keyword arguments are rejected in graph mode.

  Returns:
    The `result` of `f`, with the same structure, where each tensor has been
    replaced by the corresponding `IdentityN` output.

  Raises:
    ValueError: if any keyword arguments are passed.
    TypeError: if `f` creates a variable that is not a `ResourceVariable`, if
      `f` uses variables but `grad_fn` accepts neither `variables` nor
      `**kwargs`, or if `grad_fn` accepts them, no variables were captured and
      the current variable scope does not have `use_resource` set.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")

  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not isinstance(v, resource_variable_ops.ResourceVariable):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs. Walk the tape's report
  # directly rather than taking a set difference, so the list handed to
  # grad_fn keeps the tape's order instead of set iteration order.
  already_covered = set(args)
  variables = []
  for v in tape.watched_variables():
    if v not in already_covered:
      already_covered.add(v)  # Also drops duplicates reported by the tape.
      variables.append(v)

  grad_argspec = tf_inspect.getargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.keywords)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")

  flat_result = nest.flatten(result)
  # IdentityN inputs are laid out as [outputs..., inputs..., variables...].
  all_tensors = flat_result + args + variables

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper.

    Raises:
      ValueError: if `grad_fn` does not return one gradient per variable.
    """
    # Only the leading outputs are results of `f`; the gradients of the
    # passed-through inputs and variables are not forwarded to grad_fn.
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    # list() lets grad_fn return the variable gradients as a tuple; list + tuple
    # concatenation would otherwise raise TypeError.
    return ([None] * len(flat_result)) + input_grads + list(variable_grads)

  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode.

  Calls `f` under a `GradientTape` to find the variables it uses, then records
  a single operation on the tape whose backward function wraps the `grad_fn`
  that `f` returned.

  Note that keyword-argument values are treated as inputs when deciding which
  watched variables need gradients, but only `args` and those variables are
  recorded as input tensors of the operation.

  Args:
    f: function returning a tuple `(result, grad_fn)`.
    *args: positional inputs forwarded to `f`.
    **kwargs: keyword inputs forwarded to `f`.

  Returns:
    The `result` of `f`, with the same structure, where each tensor has been
    passed through an identity op.

  Raises:
    TypeError: if `f` uses variables that are not among its inputs but
      `grad_fn` accepts neither `variables` nor `**kwargs`.
  """
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs. Walk the tape's report
  # directly rather than through set(), so the list handed to grad_fn keeps
  # the tape's order instead of set iteration order.
  variables = []
  seen = set()
  for v in tape.watched_variables():
    if v in seen or v in all_inputs:
      continue
    seen.add(v)
    variables.append(v)

  grad_argspec = tf_inspect.getargspec(grad_fn)
  if (variables and
      not ("variables" in grad_argspec.args or grad_argspec.keywords)):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")

  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper.

    Raises:
      ValueError: if `grad_fn` does not return one gradient per variable.
    """
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    # list() lets grad_fn return the variable gradients as a tuple; list + tuple
    # concatenation would otherwise raise TypeError.
    return nest.flatten(input_grads) + list(variable_grads)

  # Gradients from actual_grad_fn line up with [args..., variables...].
  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  # flat_result is already a fresh list built above, so it is packed as-is.
  return nest.pack_sequence_as(result, flat_result)