laywerrobot/lib/python3.6/site-packages/tensorflow/python/eager/graph_callable.py
2020-08-27 21:55:39 +02:00

439 lines
17 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator that produces a callable object that executes a TensorFlow graph.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
def _default_initializer(name, shape, dtype):
"""The default initializer for variables."""
# pylint: disable=protected-access
store = variable_scope._get_default_variable_store()
initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
# pylint: enable=protected-access
return initializer[0]
class _CapturedVariable(object):
"""Variable captured by graph_callable.
Internal to the implementation of graph_callable. Created only by
_VariableCapturingScope and used only to read the variable values when calling
the function after the variables are initialized.
"""
def __init__(self, name, initializer, shape, dtype, trainable):
self.name = name
if initializer is None:
initializer = _default_initializer(name, shape, dtype)
initial_value = lambda: initializer(shape, dtype=dtype)
with context.eager_mode():
self.variable = resource_variable_ops.ResourceVariable(
initial_value=initial_value, name=name, dtype=dtype,
trainable=trainable)
self.shape = shape
self.dtype = dtype
self.placeholder = None
self.trainable = trainable
def read(self, want_gradients=True):
if want_gradients and self.trainable:
v = tape.watch_variable(self.variable)
else:
v = self.variable
return v.read_value()
class _VariableCapturingScope(object):
"""Variable-scope-like object which captures tf.get_variable calls.
This is responsible for the main difference between the initialization version
of a function object and the calling version of a function object.
capturing_scope replaces calls to tf.get_variable with placeholder tensors to
be fed the variable's current value. TODO(apassos): these placeholders should
instead be objects implementing a similar API to tf.Variable, for full
compatibility.
initializing_scope replaces calls to tf.get_variable with creation of
variables and initialization of their values. This allows eventual support of
initialized_value and friends.
TODO(apassos): once the eager mode layers API is implemented support eager
func-to-object as well.
"""
def __init__(self):
self.variables = {}
self.tf_variables = {}
@contextlib.contextmanager
def capturing_scope(self):
"""Context manager to capture variable creations.
Replaces variable accesses with placeholders.
Yields:
nothing
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
def _custom_getter( # pylint: disable=missing-docstring
getter=None,
name=None,
shape=None,
dtype=dtypes.float32,
initializer=None,
regularizer=None,
reuse=None,
trainable=None,
collections=None,
caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None,
validate_shape=True,
use_resource=None,
aggregation=variable_scope.VariableAggregation.NONE,
synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, partitioner, validate_shape, use_resource, dtype
del collections, initializer, trainable, reuse, caching_device, shape
del aggregation, synchronization
assert name in self.variables
v = self.variables[name]
return v.variable
scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
yield
@contextlib.contextmanager
def initializing_scope(self):
"""Context manager to capture variable creations.
Forcibly initializes all created variables.
Yields:
nothing
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
def _custom_getter( # pylint: disable=missing-docstring
getter=None,
name=None,
shape=None,
dtype=dtypes.float32,
initializer=None,
regularizer=None,
reuse=None,
trainable=None,
collections=None,
caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None,
validate_shape=True,
use_resource=None,
aggregation=variable_scope.VariableAggregation.NONE,
synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, collections, caching_device, partitioner
del use_resource, validate_shape, aggregation, synchronization
if name in self.tf_variables:
if reuse:
return self.tf_variables[name].initialized_value()
else:
raise ValueError("Specified reuse=%s but tried to reuse variables."
% reuse)
# TODO(apassos): ensure this is on the same device as above
v = _CapturedVariable(name, initializer, shape, dtype, trainable)
self.variables[name] = v
graph_mode_resource = v.variable.handle
if initializer is None:
initializer = _default_initializer(name, shape, dtype)
resource_variable_ops.shape_safe_assign_variable_handle(
graph_mode_resource, v.variable.shape, initializer(shape, dtype))
return v.variable
scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
yield
class _InitializingFunctionObject(object):
"""Responsible for deciding which version of func-to-object to call.
call_fn is the version which calls the function with the current values of the
variables and init_fn is the version which calls the function to initialize
all variables.
TODO(apassos): figure out a way to support initializing only _some_
variables. This requires a way to pull out a variable's initialization code
from the graph, which might not be possible in general.
"""
def __init__(self, call_fn, init_fn, shape_and_dtypes):
self._init_fn = init_fn
self._call_fn = call_fn
self.shape_and_dtypes = shape_and_dtypes
self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in
nest.flatten(self.shape_and_dtypes)]
@property
def variables(self):
return self._call_fn.variables
def __call__(self, *args):
nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
if not all([
shape.is_compatible_with(arg.shape)
for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
]):
raise ValueError(
"Declared shapes do not match argument shapes: Expected %s, found %s."
% (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))
initialized = [resource_variable_ops.var_is_initialized_op(
v.handle).numpy() for v in self._call_fn.variables]
if all(x for x in initialized):
for v in self._call_fn.variables:
if v.trainable:
tape.watch_variable(v)
return self._call_fn(*args)
elif all(not x for x in initialized):
return self._init_fn(*args)
else:
raise ValueError("Some, but not all, variables are initialized.")
def _get_graph_callable_inputs(shape_and_dtypes):
"""Maps specified shape_and_dtypes to graph inputs."""
ret = []
for x in shape_and_dtypes:
if isinstance(x, ShapeAndDtype):
ret.append(array_ops.placeholder(x.dtype, x.shape))
elif isinstance(x, (tuple, list)):
ret.append(_get_graph_callable_inputs(x))
else:
raise errors.InvalidArgumentError(
None, None, "Expected the argument to @graph_callable to be a "
"(possibly nested) list or tuple of ShapeAndDtype objects, "
"but got an object of type: %s" % type(x))
return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
def _graph_callable_internal(func, shape_and_dtypes):
"""Defines and returns a template version of func.
Under the hood we make two function objects, each wrapping a different version
of the graph-mode code. One version immediately runs variable initialization
before making the variable's Tensors available for use, while the other
version replaces the Variables with placeholders which become function
arguments and get the current variable's value.
Limitations in (2) and (4) are because this does not implement a graph-mode
Variable class which has a convert_to_tensor(as_ref=True) method and a
initialized_value method. This is fixable.
Args:
func: The tfe Python function to compile.
shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.
Raises:
ValueError: If any one of func's outputs is not a Tensor.
Returns:
Callable graph object.
"""
container = tf_ops.get_default_graph()._container # pylint: disable=protected-access
graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
# This graph will store both the initialization and the call version of the
# wrapped function. It will later be used by the backprop code to build the
# backprop graph, if necessary.
captures = {}
tmp_graph = function.CapturingGraph(captures)
# Inherit the graph key from the original graph to ensure optimizers don't
# misbehave.
tmp_graph._container = container # pylint: disable=protected-access
tmp_graph._graph_key = graph_key # pylint: disable=protected-access
with tmp_graph.as_default():
# Placeholders for the non-variable inputs.
func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
func_num_args = len(tf_inspect.getargspec(func).args)
if len(func_inputs) != func_num_args:
raise TypeError("The number of arguments accepted by the decorated "
"function `%s` (%d) must match the number of "
"ShapeAndDtype objects passed to the graph_callable() "
"decorator (%d)." %
(func.__name__, func_num_args, len(func_inputs)))
# First call the function to generate a graph which can initialize all
# variables. As a side-effect this will populate the variable capturing
# scope's view of which variables exist.
variable_captures = _VariableCapturingScope()
with variable_captures.initializing_scope(
), function.AutomaticControlDependencies() as a:
func_outputs = func(*func_inputs)
outputs_list = nest.flatten(func_outputs)
for i, x in enumerate(outputs_list):
if x is not None:
outputs_list[i] = a.mark_as_return(x)
if len(outputs_list) == 1 and outputs_list[0] is None:
outputs_list = []
output_shapes = [x.shape for x in outputs_list]
if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
raise ValueError("Found non-tensor output in %s" % str(outputs_list))
initializing_operations = tmp_graph.get_operations()
# Call the function again, now replacing usages of variables with
# placeholders. This assumes the variable capturing scope created above
# knows about all variables.
tmp_graph.clear_resource_control_flow_state()
with variable_captures.capturing_scope(
), function.AutomaticControlDependencies() as a:
captured_outputs = func(*func_inputs)
captured_outlist = nest.flatten(captured_outputs)
for i, x in enumerate(captured_outlist):
if x is not None:
captured_outlist[i] = a.mark_as_return(x)
capturing_operations = tmp_graph.get_operations()[
len(initializing_operations):]
sorted_variables = sorted(variable_captures.variables.values(),
key=lambda x: x.name)
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
else:
extra_inputs = []
extra_placeholders = []
flat_inputs = [x for x in nest.flatten(func_inputs)
if isinstance(x, tf_ops.Tensor)]
placeholder_inputs = flat_inputs+ list(extra_placeholders)
func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
# Also, what about the gradient registry of these functions? Those need to be
# addressed as well.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
function._register(f._c_func.func) # pylint: disable=protected-access
initializer_function = function.GraphModeFunction(
initialization_name,
placeholder_inputs,
extra_inputs,
tmp_graph,
initializing_operations,
func_def_outputs,
func_outputs,
output_shapes)
capture_func_def_outputs = [
x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access
captured_function = function.GraphModeFunction(
captured_function_name,
placeholder_inputs,
extra_inputs,
tmp_graph,
capturing_operations,
capture_func_def_outputs,
captured_outputs,
output_shapes,
variables=[x.variable for x in sorted_variables])
return _InitializingFunctionObject(captured_function, initializer_function,
shape_and_dtypes)
class ShapeAndDtype(object):
"""Data type that packages together shape and type information.
Used for arguments to graph callables. See graph_callable() for an example.
"""
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
def graph_callable(shape_and_dtypes):
"""Decorator that produces a callable that executes a TensorFlow graph.
When applied on a function that constructs a TensorFlow graph, this decorator
produces a callable object that:
1. Executes the graph when invoked. The first call will initialize any
variables defined in the graph.
2. Provides a .variables() method to return the list of TensorFlow variables
defined in the graph.
Note that the wrapped function is not allowed to change the values of the
variables, just use them.
The return value of the wrapped function must be one of the following:
(1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
Example:
```python
@tfe.graph_callable([tfe.ShapeAndDtype(shape(), dtype=dtypes.float32)])
def foo(x):
v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=())
return v + x
ret = foo(tfe.Tensor(2.0)) # `ret` here is a Tensor with value 3.0.
foo.variables[0].assign(7.0) # Modify the value of variable `v`.
ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0.
```
Args:
shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
that specifies shape and type information for each of the callable's
arguments. The length of this list must be equal to the number of
arguments accepted by the wrapped function.
Returns:
A callable graph object.
"""
# TODO(alive,apassos): support initialized_value and friends from tf.Variable.
assert context.executing_eagerly(), (
"graph_callable can only be used when Eager execution is enabled.")
def decorator(func):
return tf_decorator.make_decorator(func,
_graph_callable_internal(
func, shape_and_dtypes))
return decorator