# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


@tf_export("custom_gradient")
def custom_gradient(f):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations. This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of
  operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at
  x=100 is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  See also @{tf.RegisterGradient} which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s. Keyword arguments to the decorated
  function are only supported when eager execution is enabled.

  Args:
    f: function `f(x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a `Tensor` or sequence of `Tensor` inputs to the function.
       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
         TensorFlow operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
         to the `Tensor`s in `x`.  `grad_ys` is a `Tensor` or sequence of
         `Tensor`s the same size as `y` holding the initial value gradients for
         each `Tensor` in `y`.
       If `f` uses `Variable`s (that are not part of the inputs), i.e. through
       `get_variable`, then `grad_fn` should have signature
       `g(*grad_ys, variables=None)`, where `variables` is a list of the
       `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
       `grad_xs` is the same as above, and `grad_vars` is a `list` (a `tuple`
       is also accepted) with the derivatives of `Tensor`s in `y` with respect
       to the variables. `variables` may be declared as a regular parameter, as
       a Python 3 keyword-only parameter, or received through `**kwargs`.

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by @{tf.gradients}) is determined by `f(x)[1]`.
  """

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(f, *args, **kwargs)
    else:
      return _graph_mode_decorator(f, *args, **kwargs)

  return tf_decorator.make_decorator(f, decorated)


def _grad_fn_accepts_variables(grad_fn):
  """Returns True if `grad_fn` can be called with a `variables=` keyword.

  `variables` is accepted when it is a named positional-or-keyword parameter,
  a keyword-only parameter (e.g. Python 3 `def grad(*dy, variables=None)`), or
  when `grad_fn` takes `**kwargs`.
  """
  getfullargspec = getattr(tf_inspect, "getfullargspec", None)
  if getfullargspec is None:
    # This tf_inspect has no getfullargspec; getargspec cannot describe
    # keyword-only parameters, so only named args and **kwargs are checked.
    argspec = tf_inspect.getargspec(grad_fn)
    return bool("variables" in argspec.args or argspec.keywords)
  # getargspec raises ValueError on functions with keyword-only parameters,
  # which is exactly the signature the public docstring recommends.
  argspec = getfullargspec(grad_fn)
  return bool("variables" in argspec.args or
              "variables" in (argspec.kwonlyargs or []) or
              argspec.varkw)


def _captured_variables(watched_variables, inputs):
  """Returns the watched variables that are not among `inputs`.

  Membership and de-duplication use object identity (not `==`, which may be
  overloaded on tensors/variables), and the result keeps the order in which
  `watched_variables` yields them so that it is deterministic.

  Args:
    watched_variables: iterable of variables recorded by a `GradientTape`.
    inputs: sequence of the decorated function's inputs.

  Returns:
    A list of distinct variables from `watched_variables`.
  """
  input_ids = set(id(x) for x in inputs)
  seen_ids = set()
  captured = []
  for v in watched_variables:
    if id(v) in input_ids or id(v) in seen_ids:
      continue
    seen_ids.add(id(v))
    captured.append(v)
  return captured


def _call_grad_fn(grad_fn, result_grads, variables):
  """Calls the user's `grad_fn` and normalizes its results to flat lists.

  Args:
    grad_fn: the gradient function returned by the decorated function.
    result_grads: sequence of gradients w.r.t. the decorated function's
      flattened outputs.
    variables: list of captured variables; when non-empty they are passed to
      `grad_fn` as the `variables` keyword argument.

  Returns:
    A pair `(input_grads, variable_grads)` of lists: the flattened input
    gradients and one gradient per entry of `variables` (empty if none).

  Raises:
    ValueError: if `variables` is non-empty and `grad_fn` does not return
      exactly one gradient per variable.
  """
  if variables:
    input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
    if len(variable_grads) != len(variables):
      raise ValueError("Must return gradient for each variable from "
                       "@custom_gradient grad_fn.")
  else:
    input_grads = grad_fn(*result_grads)
    variable_grads = []
  # `list(...)` accepts a tuple of variable gradients too; callers concatenate
  # this with other lists, and `list + tuple` would raise TypeError.
  return nest.flatten(input_grads), list(variable_grads)


def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode.

  Runs `f` once to build the forward computation, then passes the forward
  outputs, the inputs and the captured variables through one `IdentityN` op
  whose gradient is overridden to call the user-supplied `grad_fn`.

  Raises:
    ValueError: if keyword arguments are given, or if `grad_fn` returns the
      wrong number of variable gradients.
    TypeError: if `f` creates non-resource variables; if `f` uses variables
      but `grad_fn` cannot accept a `variables` keyword; or if `grad_fn`
      accepts `variables`, none were captured, and the current variable scope
      does not use resource variables.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not isinstance(v, resource_variable_ops.ResourceVariable):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = _captured_variables(tape.watched_variables(), args)
  variables_in_signature = _grad_fn_accepts_variables(grad_fn)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warning(
          "@custom_gradient grad_fn has 'variables' in signature, but "
          "no ResourceVariables were used on the forward pass.")

  flat_result = nest.flatten(result)
  # Outputs, inputs and variables all go through a single IdentityN so that
  # its overridden gradient is asked for gradients w.r.t. every one of them.
  all_tensors = flat_result + args + variables

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    # Only the IdentityN outputs that mirror `flat_result` are returned to the
    # caller, so only their incoming gradients are forwarded to grad_fn.
    result_grads = result_grads[:len(flat_result)]
    input_grads, variable_grads = _call_grad_fn(grad_fn, result_grads,
                                                variables)
    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    return ([None] * len(flat_result)) + input_grads + variable_grads

  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])


def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode.

  Runs `f`, then records a single operation on the tape whose outputs are
  identity copies of `f`'s flattened outputs, whose inputs are the positional
  arguments plus the captured variables, and whose gradient function calls the
  user-supplied `grad_fn`. Keyword-argument values are excluded from the
  captured variables but are not recorded as inputs of that operation.

  Raises:
    TypeError: if `f` uses variables but `grad_fn` cannot accept a
      `variables` keyword argument.
    ValueError: (at gradient time) if `grad_fn` returns the wrong number of
      variable gradients.
  """
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = _captured_variables(tape.watched_variables(), all_inputs)
  if variables and not _grad_fn_accepts_variables(grad_fn):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    input_grads, variable_grads = _call_grad_fn(grad_fn, result_grads,
                                                variables)
    return input_grads + variable_grads

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  return nest.pack_sequence_as(result, flat_result)