# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator that produces a callable object that executes a TensorFlow graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _default_initializer(name, shape, dtype):
  """The default initializer for variables."""
  # pylint: disable=protected-access
  store = variable_scope._get_default_variable_store()
  initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
  # pylint: enable=protected-access
  return initializer[0]
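

# A minimal sketch of how the helper above behaves (internal API; the exact
# default comes from the variable store, typically a glorot_uniform-style
# initializer for floating-point dtypes):
#
#   init = _default_initializer("w", shape=(2, 2), dtype=dtypes.float32)
#   value = init((2, 2), dtype=dtypes.float32)  # an initial-value Tensor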


class _CapturedVariable(object):
  """Variable captured by graph_callable.

  Internal to the implementation of graph_callable. Created only by
  _VariableCapturingScope and used only to read the variable values when
  calling the function after the variables are initialized.
  """

  def __init__(self, name, initializer, shape, dtype, trainable):
    self.name = name
    if initializer is None:
      initializer = _default_initializer(name, shape, dtype)
    initial_value = lambda: initializer(shape, dtype=dtype)

    with context.eager_mode():
      self.variable = resource_variable_ops.ResourceVariable(
          initial_value=initial_value, name=name, dtype=dtype,
          trainable=trainable)
    self.shape = shape
    self.dtype = dtype
    self.placeholder = None
    self.trainable = trainable

  def read(self, want_gradients=True):
    if want_gradients and self.trainable:
      v = tape.watch_variable(self.variable)
    else:
      v = self.variable
    return v.read_value()
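

# Intended call pattern for the class above (a sketch; `read` is only used
# internally once the variables have been initialized):
#
#   captured = _CapturedVariable("v", None, shape=(), dtype=dtypes.float32,
#                                trainable=True)
#   value = captured.read()  # watches the variable on the tape, then reads it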


class _VariableCapturingScope(object):
  """Variable-scope-like object which captures tf.get_variable calls.

  This is responsible for the main difference between the initialization
  version of a function object and the calling version of a function object.

  capturing_scope replaces calls to tf.get_variable with placeholder tensors
  to be fed the variable's current value. TODO(apassos): these placeholders
  should instead be objects implementing a similar API to tf.Variable, for
  full compatibility.

  initializing_scope replaces calls to tf.get_variable with creation of
  variables and initialization of their values. This allows eventual support
  of initialized_value and friends.

  TODO(apassos): once the eager mode layers API is implemented support eager
  func-to-object as well.
  """

  def __init__(self):
    self.variables = {}
    self.tf_variables = {}

  @contextlib.contextmanager
  def capturing_scope(self):
    """Context manager to capture variable creations.

    Replaces variable accesses with placeholders.

    Yields:
      nothing
    """
    # TODO(apassos) ignoring the regularizer and partitioner here; figure out
    # how to deal with these.
    def _custom_getter(  # pylint: disable=missing-docstring
        getter=None,
        name=None,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,  # pylint: disable=redefined-outer-name
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        aggregation=variable_scope.VariableAggregation.NONE,
        synchronization=variable_scope.VariableSynchronization.AUTO):
      del getter, regularizer, partitioner, validate_shape, use_resource, dtype
      del collections, initializer, trainable, reuse, caching_device, shape
      del aggregation, synchronization
      assert name in self.variables
      v = self.variables[name]
      return v.variable

    scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
      yield

  @contextlib.contextmanager
  def initializing_scope(self):
    """Context manager to capture variable creations.

    Forcibly initializes all created variables.

    Yields:
      nothing
    """
    # TODO(apassos) ignoring the regularizer and partitioner here; figure out
    # how to deal with these.
    def _custom_getter(  # pylint: disable=missing-docstring
        getter=None,
        name=None,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,  # pylint: disable=redefined-outer-name
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        aggregation=variable_scope.VariableAggregation.NONE,
        synchronization=variable_scope.VariableSynchronization.AUTO):
      del getter, regularizer, collections, caching_device, partitioner
      del use_resource, validate_shape, aggregation, synchronization
      if name in self.tf_variables:
        if reuse:
          return self.tf_variables[name].initialized_value()
        else:
          raise ValueError("Specified reuse=%s but tried to reuse variables."
                           % reuse)
      # TODO(apassos): ensure this is on the same device as above
      v = _CapturedVariable(name, initializer, shape, dtype, trainable)
      self.variables[name] = v

      graph_mode_resource = v.variable.handle
      if initializer is None:
        initializer = _default_initializer(name, shape, dtype)
      resource_variable_ops.shape_safe_assign_variable_handle(
          graph_mode_resource, v.variable.shape, initializer(shape, dtype))
      return v.variable

    scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
      yield
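

# How the two scopes above are used together (see _graph_callable_internal
# below): the wrapped function is traced twice over the same placeholders,
# first to create and initialize the variables, then to capture reads of
# their current values.
#
#   variable_captures = _VariableCapturingScope()
#   with variable_captures.initializing_scope():
#     init_outputs = func(*func_inputs)  # creates + initializes variables
#   with variable_captures.capturing_scope():
#     call_outputs = func(*func_inputs)  # reuses the captured variables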


class _InitializingFunctionObject(object):
  """Responsible for deciding which version of func-to-object to call.

  call_fn is the version which calls the function with the current values of
  the variables and init_fn is the version which calls the function to
  initialize all variables.

  TODO(apassos): figure out a way to support initializing only _some_
  variables. This requires a way to pull out a variable's initialization code
  from the graph, which might not be possible in general.
  """

  def __init__(self, call_fn, init_fn, shape_and_dtypes):
    self._init_fn = init_fn
    self._call_fn = call_fn
    self.shape_and_dtypes = shape_and_dtypes
    self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in
                             nest.flatten(self.shape_and_dtypes)]

  @property
  def variables(self):
    return self._call_fn.variables

  def __call__(self, *args):
    nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
    if not all(
        shape.is_compatible_with(arg.shape)
        for shape, arg in zip(self.flattened_shapes, nest.flatten(args))):
      raise ValueError(
          "Declared shapes do not match argument shapes: Expected %s, found %s."
          % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))

    initialized = [resource_variable_ops.var_is_initialized_op(
        v.handle).numpy() for v in self._call_fn.variables]
    if all(initialized):
      for v in self._call_fn.variables:
        if v.trainable:
          tape.watch_variable(v)
      return self._call_fn(*args)
    elif not any(initialized):
      return self._init_fn(*args)
    else:
      raise ValueError("Some, but not all, variables are initialized.")
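

# Dispatch behavior of __call__, in brief: the first invocation (no variables
# initialized yet) runs the initializing version; later invocations run the
# capturing version against the variables' current values.
#
#   f = _InitializingFunctionObject(call_fn, init_fn, shape_and_dtypes)
#   out0 = f(x)  # nothing initialized yet -> runs init_fn
#   out1 = f(x)  # everything initialized  -> runs call_fn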


def _get_graph_callable_inputs(shape_and_dtypes):
  """Maps specified shape_and_dtypes to graph inputs."""
  ret = []
  for x in shape_and_dtypes:
    if isinstance(x, ShapeAndDtype):
      ret.append(array_ops.placeholder(x.dtype, x.shape))
    elif isinstance(x, (tuple, list)):
      ret.append(_get_graph_callable_inputs(x))
    else:
      raise errors.InvalidArgumentError(
          None, None, "Expected the argument to @graph_callable to be a "
          "(possibly nested) list or tuple of ShapeAndDtype objects, "
          "but got an object of type: %s" % type(x))

  return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
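

# For example (a sketch): given
#
#   [ShapeAndDtype(shape=(), dtype=dtypes.float32),
#    (ShapeAndDtype(shape=(2,), dtype=dtypes.int32),)]
#
# the function above returns placeholders with the same nesting:
#
#   [<float32 placeholder, shape ()>, (<int32 placeholder, shape (2,)>,)]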


def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different
  version of the graph-mode code. One version immediately runs variable
  initialization before making the variable's Tensors available for use,
  while the other version replaces the Variables with placeholders which
  become function arguments and get the current variable's value.

  These variable-related limitations (e.g. the wrapped function may read
  variables but not assign to them; see graph_callable below) exist because
  this does not implement a graph-mode Variable class which has a
  convert_to_tensor(as_ref=True) method and an initialized_value method.
  This is fixable.

  Args:
    func: The tfe Python function to compile.
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype
      objects.

  Raises:
    ValueError: If any one of func's outputs is not a Tensor.

  Returns:
    Callable graph object.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  graph_key = tf_ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of
    # the wrapped function. It will later be used by the backprop code to
    # build the backprop graph, if necessary.
    captures = {}
    tmp_graph = function.CapturingGraph(captures)
    # Inherit the graph key from the original graph to ensure optimizers don't
    # misbehave.
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      with variable_captures.initializing_scope(
      ), function.AutomaticControlDependencies() as a:
        func_outputs = func(*func_inputs)
        outputs_list = nest.flatten(func_outputs)
        for i, x in enumerate(outputs_list):
          if x is not None:
            outputs_list[i] = a.mark_as_return(x)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        outputs_list = []
      output_shapes = [x.shape for x in outputs_list]
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      tmp_graph.clear_resource_control_flow_state()
      with variable_captures.capturing_scope(
      ), function.AutomaticControlDependencies() as a:
        captured_outputs = func(*func_inputs)
        captured_outlist = nest.flatten(captured_outputs)
        for i, x in enumerate(captured_outlist):
          if x is not None:
            captured_outlist[i] = a.mark_as_return(x)
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  ids = list(sorted(captures.keys()))
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initialization_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to
  # be addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register(f._c_func.func)  # pylint: disable=protected-access
  initializer_function = function.GraphModeFunction(
      initialization_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      initializing_operations,
      func_def_outputs,
      func_outputs,
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  captured_function = function.GraphModeFunction(
      captured_function_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      capturing_operations,
      capture_func_def_outputs,
      captured_outputs,
      output_shapes,
      variables=[x.variable for x in sorted_variables])

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)


class ShapeAndDtype(object):
  """Data type that packages together shape and type information.

  Used for arguments to graph callables. See graph_callable() for an example.
  """

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype


def graph_callable(shape_and_dtypes):
  """Decorator that produces a callable that executes a TensorFlow graph.

  When applied on a function that constructs a TensorFlow graph, this
  decorator produces a callable object that:

  1. Executes the graph when invoked. The first call will initialize any
     variables defined in the graph.

  2. Provides a .variables property that returns the list of TensorFlow
     variables defined in the graph.

  Note that the wrapped function is not allowed to change the values of the
  variables, just use them.

  The return value of the wrapped function must be one of the following:
  (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.

  Example:

  ```python
  @tfe.graph_callable([tfe.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
  def foo(x):
    v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=())
    return v + x

  ret = foo(tfe.Tensor(2.0))  # `ret` here is a Tensor with value 3.0.

  foo.variables[0].assign(7.0)  # Modify the value of variable `v`.
  ret = foo(tfe.Tensor(2.0))  # `ret` here now is a Tensor with value 9.0.
  ```

  Args:
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
      that specifies shape and type information for each of the callable's
      arguments. The length of this list must be equal to the number of
      arguments accepted by the wrapped function.

  Returns:
    A callable graph object.
  """
  # TODO(alive,apassos): support initialized_value and friends from
  # tf.Variable.
  assert context.executing_eagerly(), (
      "graph_callable can only be used when Eager execution is enabled.")
  def decorator(func):
    return tf_decorator.make_decorator(func,
                                       _graph_callable_internal(
                                           func, shape_and_dtypes))

  return decorator
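

# A further usage sketch (hedged; `tfe` aliases as in the docstring example
# above): the structure passed to the decorator may be nested, and each
# top-level entry corresponds to one argument of the wrapped function.
#
#   @tfe.graph_callable(
#       [tfe.ShapeAndDtype(shape=(), dtype=dtypes.float32),
#        [tfe.ShapeAndDtype(shape=(2,), dtype=dtypes.float32)]])
#   def bar(x, xs):
#     ...  # `xs` arrives as a one-element list of Tensors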