1242 lines
42 KiB
Python
1242 lines
42 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# =============================================================================
|
||
|
"""Python front-end supports for functions.
|
||
|
|
||
|
NOTE: functions are currently experimental and subject to change!
|
||
|
"""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import collections
|
||
|
import hashlib
|
||
|
import sys
|
||
|
|
||
|
from tensorflow.core.framework import attr_value_pb2
|
||
|
from tensorflow.core.framework import function_pb2
|
||
|
from tensorflow.python import pywrap_tensorflow as c_api
|
||
|
from tensorflow.python.eager import context
|
||
|
from tensorflow.python.framework import c_api_util
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import graph_to_function_def
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.ops import array_ops
|
||
|
from tensorflow.python.ops import cond_v2_impl
|
||
|
from tensorflow.python.ops import resource_variable_ops
|
||
|
from tensorflow.python.ops import variable_scope as vs
|
||
|
from tensorflow.python.util import compat
|
||
|
from tensorflow.python.util import tf_contextlib
|
||
|
from tensorflow.python.util import tf_decorator
|
||
|
from tensorflow.python.util import tf_inspect
|
||
|
|
||
|
# This is to avoid a circular dependency with cond_v2_impl.
|
||
|
cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
|
||
|
|
||
|
|
||
|
class Defun(object):
|
||
|
"""Decorator used to define TensorFlow functions.
|
||
|
|
||
|
Use this decorator to make a Python function usable directly as a TensorFlow
|
||
|
function.
|
||
|
|
||
|
The decorated function must add ops to the default graph and return zero or
|
||
|
more `Tensor` objects. Call the decorator with named arguments, one for each
|
||
|
argument of the function to decorate, with the expected type of the argument
|
||
|
as value.
|
||
|
|
||
|
For example if the function to decorate accepts two `tf.float32` arguments
|
||
|
named `x` and `y`, call the decorator with:
|
||
|
|
||
|
@Defun(tf.float32, tf.float32)
|
||
|
def foo(x, y):
|
||
|
...
|
||
|
|
||
|
When you call the decorated function it will add `call` ops to the
|
||
|
default graph and adds the definition of the function into the
|
||
|
default graph. Because the addition of the function into the graph
|
||
|
is deferred, the decorator can be used anywhere in the program.
|
||
|
|
||
|
Any variables created inside of the function are hoisted into the outer graph.
|
||
|
Note that the variables are created in the variable scope that was active
|
||
|
during the first call to the function. Subsequent function calls will refer to
|
||
|
the same set of variables.
|
||
|
|
||
|
Definitions of functions in a graph are frozen as soon as the graph is used to
|
||
|
create a session. However, new functions and new calls to existing functions
|
||
|
may be added to the graph, with the new functions themselves becoming
|
||
|
immediately frozen.
|
||
|
|
||
|
Example, but also see the [How To on functions](link_needed).
|
||
|
|
||
|
```python
|
||
|
# Defining the function.
|
||
|
@tf.Defun(tf.float32, tf.float32)
|
||
|
def MyFunc(x, y):
|
||
|
return x + y, x - y
|
||
|
|
||
|
# Building the graph.
|
||
|
a = tf.constant([1.0])
|
||
|
b = tf.constant([2.0])
|
||
|
c, d = MyFunc(a, b, name='mycall')
|
||
|
```
|
||
|
"""
|
||
|
|
||
|
def __init__(self, *input_types, **kwargs):
|
||
|
"""Create a `Defun` decorator.
|
||
|
|
||
|
Args:
|
||
|
*input_types: A list of `tf.DType`
|
||
|
**kwargs: Optional keyword arguments, including
|
||
|
func_name - (optional). A python string, the name to use to
|
||
|
declare this `Function` in the graph.
|
||
|
|
||
|
grad_func - (optional). A function implementing the gradient
|
||
|
of the function-to-register. This is must be a
|
||
|
`_DefinedFunction` object. The gradient
|
||
|
function must satisfy the criterion defined in
|
||
|
function.proto:GradientDef.
|
||
|
|
||
|
python_grad_func - (optional). A function implementing the
|
||
|
gradient of the function python-side. This function must
|
||
|
take the current op and the gradients w.r.t. its outputs,
|
||
|
and return the gradients w.r.t. the inputs. That is it must
|
||
|
implement the interface expected by `tf.RegisterGradient`).
|
||
|
This will be called by tf.gradients to add the gradient ops
|
||
|
to the graph. At most one of grad_func and python_grad_func
|
||
|
can be specified.
|
||
|
|
||
|
out_names = (optional). A list of strings, one per output
|
||
|
tensor.
|
||
|
|
||
|
shape_func - (optional). A function taking the op and returning a list
|
||
|
of static shapes to set for the function's outputs.
|
||
|
"""
|
||
|
self._input_types = input_types
|
||
|
self._func_name = kwargs.pop("func_name", None)
|
||
|
self._grad_func = kwargs.pop("grad_func", None)
|
||
|
self._python_grad_func = kwargs.pop("python_grad_func", None)
|
||
|
self._out_names = kwargs.pop("out_names", None)
|
||
|
self._extra_kwargs = kwargs
|
||
|
|
||
|
def __call__(self, func):
|
||
|
# Various sanity checks on the callable func.
|
||
|
if not callable(func):
|
||
|
raise ValueError("func %s must be callable" % func)
|
||
|
|
||
|
# Func should not use kwargs and defaults.
|
||
|
argspec = tf_inspect.getargspec(func)
|
||
|
if argspec.keywords or argspec.defaults:
|
||
|
raise ValueError("Functions with argument defaults or keyword "
|
||
|
"arguments are not supported.")
|
||
|
|
||
|
# Computes how many arguments 'func' has.
|
||
|
min_args = len(argspec.args)
|
||
|
max_args = min_args
|
||
|
if argspec.varargs:
|
||
|
max_args = 1000000
|
||
|
argnames = argspec.args
|
||
|
if tf_inspect.ismethod(func):
|
||
|
# 1st argument is the "class" type.
|
||
|
min_args -= 1
|
||
|
argnames = argnames[1:]
|
||
|
|
||
|
if self._input_types:
|
||
|
# If Defun is given a list of types for the inputs, the number
|
||
|
# of input types should be compatible with 'func'.
|
||
|
num = len(self._input_types)
|
||
|
if num < min_args or num > max_args:
|
||
|
raise ValueError(
|
||
|
"The function has fewer arguments than the number of specified "
|
||
|
"input types.")
|
||
|
return _DefinedFunction(
|
||
|
func,
|
||
|
argnames,
|
||
|
self._input_types,
|
||
|
self._func_name,
|
||
|
self._grad_func,
|
||
|
self._python_grad_func,
|
||
|
out_names=self._out_names,
|
||
|
**self._extra_kwargs)
|
||
|
|
||
|
# 'func' expects no arguments and input types is an empty list.
|
||
|
if min_args == 0 and max_args == 0:
|
||
|
return _DefinedFunction(
|
||
|
func, [], [],
|
||
|
self._func_name,
|
||
|
self._grad_func,
|
||
|
self._python_grad_func,
|
||
|
out_names=self._out_names,
|
||
|
**self._extra_kwargs)
|
||
|
|
||
|
# Input types are unknown. It's an overloaded function and hence
|
||
|
# its definition needs to be deferred until it's called.
|
||
|
return _OverloadedFunction(
|
||
|
func,
|
||
|
argnames,
|
||
|
self._func_name,
|
||
|
self._grad_func,
|
||
|
self._python_grad_func,
|
||
|
out_names=self._out_names,
|
||
|
**self._extra_kwargs)
|
||
|
|
||
|
|
||
|
class _DefinedFunction(object):
|
||
|
"""_DefinedFunction encapsulates a function definition and its properties.
|
||
|
|
||
|
Attributes:
|
||
|
name: The function name.
|
||
|
definition: The definition of this function. A FunctionDef proto.
|
||
|
grad_func_name: If not None, the name of this function's gradient function.
|
||
|
python_grad_func: A python callable implementing the gradient of
|
||
|
the function python-side.
|
||
|
"""
|
||
|
|
||
|
def __init__(self,
|
||
|
func,
|
||
|
argnames,
|
||
|
input_types,
|
||
|
func_name=None,
|
||
|
grad_func=None,
|
||
|
python_grad_func=None,
|
||
|
out_names=None,
|
||
|
shape_func=None,
|
||
|
capture_by_value=False,
|
||
|
**kwargs):
|
||
|
"""Creates _DefinedFunction.
|
||
|
|
||
|
Args:
|
||
|
func: A python callable which constructs a tf function body.
|
||
|
argnames: A list of strings for function argument names.
|
||
|
input_types: The function's argument types. Can be a tuple, list of
|
||
|
tf data types.
|
||
|
func_name: The function name. Defaults to None, in which derives from
|
||
|
'func'.
|
||
|
grad_func: This function's gradient function, if not None. Defaults
|
||
|
to None.
|
||
|
python_grad_func: A python callable implementing the gradient of
|
||
|
the function python-side.
|
||
|
out_names: An optional list of strings for the function return value
|
||
|
names.
|
||
|
shape_func: An optional function mapping an op to a list of static
|
||
|
output shapes.
|
||
|
capture_by_value: Boolean (defaults to False). If True, captured values
|
||
|
will be copied into the function body.
|
||
|
**kwargs: The keyword arguments. **kwargs is passed to every call
|
||
|
site of this function.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: The function definition is invalid.
|
||
|
|
||
|
"""
|
||
|
self._func = func
|
||
|
self._input_types = input_types
|
||
|
self._func_name = func_name
|
||
|
self._grad_func = grad_func
|
||
|
self._python_grad_func = python_grad_func
|
||
|
self._out_names = out_names
|
||
|
self._shape_func = shape_func
|
||
|
self._capture_by_value = capture_by_value
|
||
|
self._extra_kwargs = kwargs
|
||
|
# Constructed only when C API is disabled, lazily
|
||
|
self._definition = None
|
||
|
# Constructed only when C API is enabled, lazily
|
||
|
self._c_func = None
|
||
|
self._sub_functions = dict() # Constructed with _definition or _c_func
|
||
|
device_stack = ops.get_default_graph()._device_function_stack # pylint: disable=protected-access
|
||
|
# Get the innermost device if possbile.
|
||
|
self._caller_device = device_stack[-1] if device_stack else None
|
||
|
|
||
|
# Cached OpDef for this function. When C API is enabled, this is
|
||
|
# the only part of FunctionDef that we cache in Python. When C API
|
||
|
# is disabled the whole _definition is available and this is simply
|
||
|
# another reference to _definition.signature
|
||
|
self._op_def = None
|
||
|
|
||
|
assert isinstance(input_types, (list, tuple))
|
||
|
self._arg_types = input_types
|
||
|
self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
|
||
|
for i in range(len(input_types))]
|
||
|
|
||
|
@property
|
||
|
def name(self):
|
||
|
"""Function name."""
|
||
|
self._create_definition_if_needed()
|
||
|
return self._func_name
|
||
|
|
||
|
@property
|
||
|
def definition(self):
|
||
|
"""Function definition proto."""
|
||
|
self._create_definition_if_needed()
|
||
|
if self._c_func:
|
||
|
with c_api_util.tf_buffer() as buf:
|
||
|
c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
|
||
|
fdef = function_pb2.FunctionDef()
|
||
|
proto_data = c_api.TF_GetBuffer(buf)
|
||
|
fdef.ParseFromString(compat.as_bytes(proto_data))
|
||
|
return fdef
|
||
|
return self._definition
|
||
|
|
||
|
@property
|
||
|
def _signature(self):
|
||
|
self._create_definition_if_needed()
|
||
|
return self._op_def
|
||
|
|
||
|
def set_grad_func(self, grad_func):
|
||
|
"""Specifies the gradient function of this function."""
|
||
|
assert not self._grad_func
|
||
|
assert isinstance(grad_func, _DefinedFunction)
|
||
|
self._grad_func = grad_func
|
||
|
|
||
|
@property
|
||
|
def grad_func_name(self):
|
||
|
"""Its gradient function's name."""
|
||
|
return self._grad_func.name if self._grad_func else None
|
||
|
|
||
|
@property
|
||
|
def python_grad_func(self):
|
||
|
"""Python gradient function callable."""
|
||
|
return self._python_grad_func
|
||
|
|
||
|
@property
|
||
|
def declared_input_types(self):
|
||
|
"""Returns the list of data types of explicit declared inputs."""
|
||
|
return self._input_types
|
||
|
|
||
|
@property
|
||
|
def captured_inputs(self):
|
||
|
"""Returns the list of implicitly captured inputs."""
|
||
|
self._create_definition_if_needed()
|
||
|
return self._extra_inputs
|
||
|
|
||
|
@property
|
||
|
def stateful_ops(self):
|
||
|
"""Returns the list of stateful ops in function definition.
|
||
|
|
||
|
Returns:
|
||
|
A list of (op.name, op.type) pairs.
|
||
|
"""
|
||
|
self._create_definition_if_needed()
|
||
|
return self._stateful_ops
|
||
|
|
||
|
def _create_definition_if_needed(self):
|
||
|
"""Creates the function definition if it's not created yet."""
|
||
|
with context.graph_mode():
|
||
|
self._create_definition_if_needed_impl()
|
||
|
|
||
|
def _create_definition_if_needed_impl(self):
|
||
|
"""This is not what you want, see _create_definition_if_needed."""
|
||
|
if self._definition is not None or self._c_func is not None:
|
||
|
return
|
||
|
|
||
|
temp_graph = func_graph_from_py_func(
|
||
|
self._func, self._arg_names, self._arg_types, self._func_name,
|
||
|
self._capture_by_value, self._caller_device)
|
||
|
|
||
|
self._extra_inputs = temp_graph.extra_inputs
|
||
|
# pylint: disable=protected-access
|
||
|
self._sub_functions = temp_graph._functions
|
||
|
# pylint: enable=protected-access
|
||
|
|
||
|
# Extra kwargs are treated as attrs on the function def.
|
||
|
if self._func_name:
|
||
|
base_func_name = self._func_name
|
||
|
else:
|
||
|
base_func_name = _get_func_name(self._func)
|
||
|
if self._grad_func:
|
||
|
base_func_name += ("_%s" % self._grad_func.name)
|
||
|
kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
|
||
|
|
||
|
if not temp_graph._c_graph: # pylint: disable=protected-access
|
||
|
# Build the FunctionDef
|
||
|
self._definition = graph_to_function_def.graph_to_function_def(
|
||
|
temp_graph,
|
||
|
temp_graph.get_operations(),
|
||
|
temp_graph.inputs,
|
||
|
temp_graph.outputs,
|
||
|
out_names=self._out_names)
|
||
|
|
||
|
for k in kwargs_attr:
|
||
|
self._definition.attr[k].CopyFrom(kwargs_attr[k])
|
||
|
|
||
|
# Hash the definition and its dependencies.
|
||
|
self._hash_str = self._create_hash_str(
|
||
|
self._definition.signature.input_arg,
|
||
|
self._definition.signature.output_arg, self._definition.node_def)
|
||
|
|
||
|
# Finally, we decide the function name to use. If not specified,
|
||
|
# make up something which is almost certainly unique (but deterministic).
|
||
|
if not self._func_name:
|
||
|
self._func_name = "_".join([base_func_name, self._hash_str])
|
||
|
self._definition.signature.name = self._func_name
|
||
|
if self._func.__doc__:
|
||
|
self._definition.signature.description = self._func.__doc__
|
||
|
|
||
|
self._op_def = self._definition.signature
|
||
|
else: # C API is enabled
|
||
|
output_names = ([compat.as_bytes(x) for x in self._out_names]
|
||
|
if self._out_names else [])
|
||
|
description = self._func.__doc__ or None
|
||
|
# pylint: disable=protected-access
|
||
|
c_func = c_api.TF_GraphToFunction_wrapper(
|
||
|
temp_graph._c_graph,
|
||
|
base_func_name,
|
||
|
self._func_name is None, # append_hash_to_fn_name
|
||
|
None, # opers
|
||
|
[t._as_tf_output() for t in temp_graph.inputs],
|
||
|
[t._as_tf_output() for t in temp_graph.outputs],
|
||
|
output_names,
|
||
|
None, # opts
|
||
|
description)
|
||
|
self._c_func = c_api_util.ScopedTFFunction(c_func)
|
||
|
# pylint: enable=protected-access
|
||
|
self._set_c_attrs(kwargs_attr)
|
||
|
|
||
|
# Set cached fields: _op_def and _func_name (if not already set)
|
||
|
self._op_def = self.definition.signature
|
||
|
if self._func_name:
|
||
|
assert self._func_name == self._op_def.name
|
||
|
else:
|
||
|
self._func_name = compat.as_str(self._op_def.name)
|
||
|
|
||
|
self._stateful_ops = [(op.name, op.type)
|
||
|
for op in temp_graph.get_operations()
|
||
|
if op.op_def.is_stateful]
|
||
|
|
||
|
def _set_c_attrs(self, attrs):
|
||
|
"""Sets `attrs` as attributes of self._c_func.
|
||
|
|
||
|
Requires that self._c_func is not None.
|
||
|
|
||
|
Args:
|
||
|
attrs: a dictionary from attribute name to attribute proto value
|
||
|
"""
|
||
|
for name, attr_value in attrs.items():
|
||
|
serialized = attr_value.SerializeToString()
|
||
|
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
|
||
|
# It might be worth creating a convenient way to re-use the same status.
|
||
|
c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
|
||
|
serialized)
|
||
|
|
||
|
def _create_hash_str(self, input_arg, output_arg, node_def):
|
||
|
"""Creates an 8-character string unique to this input.
|
||
|
|
||
|
Args:
|
||
|
input_arg: the input_arg field of an OpDef
|
||
|
(e.g. self._definition.signature.input_arg)
|
||
|
output_arg: the output_arg field of an OpDef
|
||
|
(e.g. self._definition.signature.output_arg)
|
||
|
node_def: the node_def field of a FunctionDef
|
||
|
(e.g. self._definition.node_def)
|
||
|
|
||
|
Returns:
|
||
|
The unique string for this input
|
||
|
"""
|
||
|
hasher = hashlib.sha1()
|
||
|
|
||
|
def update_num(n):
|
||
|
hasher.update(compat.as_bytes("%x" % n))
|
||
|
|
||
|
def update_str(s):
|
||
|
update_num(len(s))
|
||
|
hasher.update(compat.as_bytes(s))
|
||
|
|
||
|
def update_strs(slist):
|
||
|
update_num(len(slist))
|
||
|
for s in slist:
|
||
|
update_str(s)
|
||
|
|
||
|
for adef in input_arg:
|
||
|
update_str(adef.SerializeToString())
|
||
|
|
||
|
for adef in output_arg:
|
||
|
update_str(adef.SerializeToString())
|
||
|
|
||
|
for n in sorted(node_def, key=lambda n: n.name):
|
||
|
update_str(n.name)
|
||
|
update_str(n.op)
|
||
|
update_strs(n.input)
|
||
|
update_num(len(n.attr))
|
||
|
# NOTE: protobuf map serialization does not guarantee ordering.
|
||
|
for k in sorted(n.attr):
|
||
|
update_str(k)
|
||
|
update_str(n.attr[k].SerializeToString())
|
||
|
|
||
|
return hasher.hexdigest()[:8]
|
||
|
|
||
|
def add_to_graph(self, g):
|
||
|
"""Adds this function into the graph g."""
|
||
|
self._create_definition_if_needed()
|
||
|
|
||
|
# Adds this function into 'g'.
|
||
|
# pylint: disable=protected-access
|
||
|
if context.executing_eagerly():
|
||
|
context.context().add_function_def(self.definition)
|
||
|
else:
|
||
|
g._add_function(self)
|
||
|
# pylint: enable=protected-access
|
||
|
|
||
|
# Ensures related sub-routines are defined in 'g', too.
|
||
|
for f in self._sub_functions.values():
|
||
|
f.add_to_graph(g)
|
||
|
|
||
|
# Adds its gradient function, too.
|
||
|
if self._grad_func:
|
||
|
self._grad_func.add_to_graph(g)
|
||
|
|
||
|
def __call__(self, *args, **kwargs):
|
||
|
self.add_to_graph(ops.get_default_graph())
|
||
|
args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
|
||
|
ret, op = _call(self._signature, *args, **kwargs)
|
||
|
|
||
|
# Set a hidden attr in 'op' so that gradients_impl can refer back
|
||
|
# to this _DefinedFunction instance to access python_grad_func.
|
||
|
assert isinstance(op, ops.Operation)
|
||
|
setattr(op, "__defun", self)
|
||
|
|
||
|
if self._shape_func is not None:
|
||
|
shapes = self._shape_func(op)
|
||
|
if len(shapes) != len(op.outputs):
|
||
|
raise ValueError("shape_func produced %d shapes for %d outputs" %
|
||
|
(len(shapes), len(op.outputs)))
|
||
|
for (t, shape) in zip(op.outputs, shapes):
|
||
|
t.set_shape(shape)
|
||
|
return ret
|
||
|
|
||
|
|
||
|
class _OverloadedFunction(object):
|
||
|
"""_OverloadedFunction encapsulates an overloaded function.
|
||
|
|
||
|
_OverloadedFunction maintains a mapping from input types to
|
||
|
instantiated _DefinedFunction in self._overload.
|
||
|
|
||
|
"""
|
||
|
|
||
|
def __init__(self,
|
||
|
func,
|
||
|
argnames,
|
||
|
func_name=None,
|
||
|
grad_func=None,
|
||
|
python_grad_func=None,
|
||
|
out_names=None,
|
||
|
**kwargs):
|
||
|
"""Creates _DefinedFunction.
|
||
|
|
||
|
Args:
|
||
|
func: A python callable which constructs a tf function body.
|
||
|
argnames: A list of strings for function argument names.
|
||
|
func_name: The function name. Defaults to None, in which derives from
|
||
|
'func'.
|
||
|
grad_func: This function's gradient function, if not None. Defaults
|
||
|
to None.
|
||
|
python_grad_func: A python callable implementing the gradient of
|
||
|
the function python-side.
|
||
|
out_names: A list of strings for the function return value names.
|
||
|
**kwargs: The keyword arguments. **kwargs is passed to every call
|
||
|
site of this function.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: The function definition is invalid.
|
||
|
|
||
|
"""
|
||
|
self._func = func
|
||
|
self._argnames = argnames
|
||
|
self._func_name = func_name
|
||
|
assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
|
||
|
self._grad_func = grad_func
|
||
|
self._python_grad_func = python_grad_func
|
||
|
self._out_names = out_names
|
||
|
self._extra_kwargs = kwargs
|
||
|
self._overload = {}
|
||
|
|
||
|
def instantiate(self, input_types):
|
||
|
"""Instantiate this function given input argument types.
|
||
|
|
||
|
Args:
|
||
|
input_types: A list of data types for the inputs.
|
||
|
|
||
|
Returns:
|
||
|
_DefinedFunction for the given input types.
|
||
|
|
||
|
"""
|
||
|
# Stringify the type list.
|
||
|
key = _type_list_to_str(input_types)
|
||
|
defined = self._overload.get(key)
|
||
|
if not defined:
|
||
|
# If not defined yet, define the function given the input types.
|
||
|
name = self._func_name
|
||
|
if name is not None:
|
||
|
name = "_".join([name, key])
|
||
|
defined = _DefinedFunction(
|
||
|
self._func,
|
||
|
self._argnames,
|
||
|
input_types,
|
||
|
name,
|
||
|
None,
|
||
|
self._python_grad_func,
|
||
|
out_names=self._out_names,
|
||
|
**self._extra_kwargs)
|
||
|
_ = defined.name # Fully instantiate the function definition.
|
||
|
if self._grad_func:
|
||
|
# If _grad_func is given, it is another
|
||
|
# _OverloadedFunction. We need to instantiate it with the
|
||
|
# right input types.
|
||
|
output_types = [
|
||
|
dtypes.DType(_.type) for _ in defined._signature.output_arg # pylint: disable=protected-access
|
||
|
]
|
||
|
# pylint: disable=protected-access
|
||
|
defined._grad_func = self._grad_func.instantiate(input_types +
|
||
|
output_types)
|
||
|
# pylint: enable=protected-access
|
||
|
self._overload[key] = defined
|
||
|
return defined
|
||
|
|
||
|
def __call__(self, *args, **kwargs):
|
||
|
input_types = []
|
||
|
args = list(args)
|
||
|
for (i, x) in enumerate(args):
|
||
|
x = ops.convert_to_tensor(x)
|
||
|
if not isinstance(x, ops.Tensor):
|
||
|
raise ValueError("Expect a Tensor but get ", x)
|
||
|
input_types.append(x.dtype)
|
||
|
args[i] = x
|
||
|
return self.instantiate(input_types)(*args, **kwargs)
|
||
|
|
||
|
|
||
|
class _FuncGraph(ops.Graph):
|
||
|
"""A helper for constructing a function.
|
||
|
|
||
|
_FuncGraph overrides ops.Graph's create_op() so that we can keep
|
||
|
track of all inputs into every op created inside the function. If
|
||
|
any input is from other graphs, we keep track of it in self.capture
|
||
|
and substitute the input with a place holder.
|
||
|
|
||
|
Each captured input's corresponding place holder is converted into a
|
||
|
function argument and the caller passes in the captured tensor.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, name, capture_by_value, *args, **kwargs):
|
||
|
super(_FuncGraph, self).__init__(*args, **kwargs)
|
||
|
self._capture_by_value = capture_by_value
|
||
|
self._building_function = True
|
||
|
self._outer_graph = ops.get_default_graph()
|
||
|
self._vscope = vs.get_variable_scope()
|
||
|
self._old_custom_getter = self._vscope.custom_getter
|
||
|
|
||
|
# The name of the function.
|
||
|
self.name = name
|
||
|
# Placeholder tensors representing the inputs to this function. The tensors
|
||
|
# are in this _FuncGraph.
|
||
|
self.inputs = []
|
||
|
# Tensors that will be returned this function. The tensors are in this
|
||
|
# _FuncGraph.
|
||
|
self.outputs = []
|
||
|
# Maps external tensor -> internal tensor (e.g. input placeholder).
|
||
|
self._captured = {}
|
||
|
# The external tensors that have been captured as inputs and must be passed
|
||
|
# to this function (empty if capturing by value, otherwise these are the
|
||
|
# keys of _captured).
|
||
|
self.extra_inputs = []
|
||
|
# Input placeholders that been added for captured values (empty if capturing
|
||
|
# by value).
|
||
|
self.extra_args = []
|
||
|
# Captured variables.
|
||
|
# TODO(skyewm): is this needed?
|
||
|
self.extra_vars = []
|
||
|
|
||
|
# pylint: disable=g-doc-return-or-yield
|
||
|
|
||
|
@tf_contextlib.contextmanager
|
||
|
def container(self, container_name):
|
||
|
"""Returns a context manager that specifies the resource container to use.
|
||
|
|
||
|
Overridden from @{tf.Graph} to update both the init_scope container
|
||
|
and the present inner container. This is necessary to make sure setting
|
||
|
containers applies correctly both to created variables and to stateful
|
||
|
ops.
|
||
|
|
||
|
Args:
|
||
|
container_name: container name string.
|
||
|
|
||
|
Returns:
|
||
|
A context manager for defining resource containers for stateful ops,
|
||
|
yields the container name.
|
||
|
"""
|
||
|
original_container = self._container
|
||
|
# pylint: disable=protected-access
|
||
|
with ops.init_scope():
|
||
|
original_init_container = ops.get_default_graph()._container
|
||
|
try:
|
||
|
self._container = container_name
|
||
|
with ops.init_scope():
|
||
|
ops.get_default_graph()._container = container_name
|
||
|
yield self._container
|
||
|
finally:
|
||
|
self._container = original_container
|
||
|
with ops.init_scope():
|
||
|
ops.get_default_graph()._container = original_init_container
|
||
|
# pylint: enable=protected-access
|
||
|
|
||
|
# pylint: enable=g-doc-return-or-yield
|
||
|
|
||
|
def getvar(
|
||
|
self,
|
||
|
getter,
|
||
|
name,
|
||
|
shape=None,
|
||
|
dtype=None,
|
||
|
initializer=None,
|
||
|
reuse=None,
|
||
|
trainable=True,
|
||
|
collections=None, # pylint: disable=redefined-outer-name
|
||
|
use_resource=None,
|
||
|
**kwargs):
|
||
|
"""A custom variable getter."""
|
||
|
# Here, we switch the default graph to the outer graph and ask the
|
||
|
# variable scope in which the function is defined to give us the
|
||
|
# variable. The variable is stashed in extra_vars and returned to
|
||
|
# the caller.
|
||
|
#
|
||
|
# We capture these variables so that the variable definition is
|
||
|
# hoisted upward to the outer most graph.
|
||
|
with self._outer_graph.as_default():
|
||
|
# pylint: disable=protected-access
|
||
|
var = self._vscope.get_variable(
|
||
|
vs._get_default_variable_store(),
|
||
|
name,
|
||
|
shape=shape,
|
||
|
dtype=dtype,
|
||
|
initializer=initializer,
|
||
|
reuse=reuse,
|
||
|
trainable=trainable,
|
||
|
collections=collections,
|
||
|
use_resource=use_resource)
|
||
|
self.extra_vars.append(var)
|
||
|
if isinstance(var, resource_variable_ops.ResourceVariable):
|
||
|
# For resource-based variables read the variable outside the function
|
||
|
# and pass in the value. This ensures that the function is pure and
|
||
|
# differentiable. TODO(apassos) this may have performance problems if
|
||
|
# the function will only do embedding lookups on the variable.
|
||
|
return var.value()
|
||
|
return var
|
||
|
|
||
|
def create_op(self, op_type, inputs, data_types, **kwargs):
|
||
|
for i, x in enumerate(inputs):
|
||
|
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
||
|
inputs[i] = self.capture(x)
|
||
|
return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
|
||
|
**kwargs)
|
||
|
|
||
|
def capture(self, tensor, name=None):
|
||
|
"""Adds the given tensor to this graph and returns the captured tensor."""
|
||
|
if tensor in self._captured:
|
||
|
# Captured already.
|
||
|
return self._captured[tensor]
|
||
|
elif self._capture_by_value:
|
||
|
return self._add_tensor_and_parents(tensor)
|
||
|
else:
|
||
|
return self._capture_tensor_as_extra_input(tensor, name)
|
||
|
|
||
|
def _capture_tensor_as_extra_input(self, tensor, name=None):
|
||
|
# Substitute with a placeholder.
|
||
|
self.extra_inputs.append(tensor)
|
||
|
# Hoist the new input placeholder out of any control flow context
|
||
|
# we're currently in.
|
||
|
with ops.control_dependencies(None):
|
||
|
ph = array_ops.placeholder(
|
||
|
tensor.dtype, shape=tensor.get_shape(), name=name)
|
||
|
# pylint: disable=protected-access
|
||
|
if ops._USE_C_SHAPES:
|
||
|
if isinstance(tensor, ops.EagerTensor):
|
||
|
handle_data = tensor._handle_data
|
||
|
if handle_data:
|
||
|
handle_data = handle_data.SerializeToString()
|
||
|
else:
|
||
|
handle_data = c_api.GetResourceHandleShapeAndType(
|
||
|
tensor.graph._c_graph, tensor._as_tf_output())
|
||
|
|
||
|
if handle_data:
|
||
|
c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
|
||
|
ph._as_tf_output(),
|
||
|
compat.as_bytes(handle_data))
|
||
|
else:
|
||
|
ph._handle_data = tensor._handle_data
|
||
|
# pylint: enable=protected-access
|
||
|
self.inputs.append(ph)
|
||
|
self._captured[tensor] = ph
|
||
|
self.extra_args.append(ph)
|
||
|
if _is_guaranteed_const(tensor):
|
||
|
with ops.control_dependencies(None):
|
||
|
return array_ops.guarantee_const(ph)
|
||
|
else:
|
||
|
return ph
|
||
|
|
||
|
def _add_tensor_and_parents(self, tensor):
|
||
|
op = self._add_op_and_parents(tensor.op)
|
||
|
return op.outputs[tensor.value_index]
|
||
|
|
||
|
def _add_op_and_parents(self, op):
|
||
|
# pylint: disable=protected-access
|
||
|
op_def = graph_to_function_def._get_op_def(op)
|
||
|
# pylint: enable=protected-access
|
||
|
if op_def.is_stateful:
|
||
|
raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
|
||
|
"by value." % (op.name, op.type))
|
||
|
elif op.type in ("Placeholder", "PlaceholderV2"):
|
||
|
raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
|
||
|
"by value." % (op.name, op.type))
|
||
|
|
||
|
captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
|
||
|
|
||
|
captured_op = self.create_op(
|
||
|
op.type,
|
||
|
captured_inputs, [o.dtype for o in op.outputs],
|
||
|
name=op.name,
|
||
|
attrs=op.node_def.attr,
|
||
|
op_def=op_def)
|
||
|
|
||
|
for t, captured_t in zip(op.outputs, captured_op.outputs):
|
||
|
self._captured[t] = captured_t
|
||
|
|
||
|
return captured_op
|
||
|
|
||
|
|
||
|
def func_graph_from_py_func(func, arg_names, arg_types, name=None,
|
||
|
capture_by_value=False, device=None,
|
||
|
colocation_stack=None, container=None,
|
||
|
collections_ref=None):
|
||
|
"""Returns a _FuncGraph generated from `func`.
|
||
|
|
||
|
Args:
|
||
|
func: A Python callable which constructs a TF function body. The arguments
|
||
|
must correspond to `arg_types`. Returns a value or list/tuple of values.
|
||
|
No returned value can be None.
|
||
|
arg_names: A sequence of strings for the function argument names.
|
||
|
arg_types: A sequence of the function's argument types.
|
||
|
name: The function name. If None, the name is derived from `func`.
|
||
|
capture_by_value: boolean. If True, captured values will be copied into the
|
||
|
function body.
|
||
|
device: device name or function.
|
||
|
colocation_stack: A colocation stack (list) the _FuncGraph should use.
|
||
|
container: A container name the _FuncGraph should start with.
|
||
|
collections_ref: A reference to a collections dict the _FuncGraph should
|
||
|
use internally.
|
||
|
|
||
|
Returns:
|
||
|
A _FuncGraph.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: if func returns None.
|
||
|
"""
|
||
|
if not name:
|
||
|
name = _get_func_name(func)
|
||
|
func_graph = _FuncGraph(name, capture_by_value)
|
||
|
|
||
|
with func_graph.as_default(), ops.device(device):
|
||
|
# pylint: disable=protected-access
|
||
|
if collections_ref is not None:
|
||
|
func_graph._collections = collections_ref
|
||
|
if container is not None:
|
||
|
func_graph._container = container
|
||
|
if colocation_stack is not None:
|
||
|
func_graph._colocation_stack = colocation_stack
|
||
|
# pylint: enable=protected-access
|
||
|
|
||
|
# Create placeholders for the function arguments.
|
||
|
for (argname, argtype) in zip(arg_names, arg_types):
|
||
|
argholder = array_ops.placeholder(argtype, name=argname)
|
||
|
func_graph.inputs.append(argholder)
|
||
|
# Call func and gather the output tensors.
|
||
|
with vs.variable_scope("", custom_getter=func_graph.getvar):
|
||
|
outputs = func(*func_graph.inputs)
|
||
|
|
||
|
# There is no way of distinguishing between a function not returning
|
||
|
# anything and a function returning None in Python.
|
||
|
# We need to allow the former and ideally want to forbid the latter as
|
||
|
# it is most likely user error.
|
||
|
# TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
|
||
|
# allow users to explicitly mark the function as not returning anything.
|
||
|
# For now, we allow a single None return and interpret it as a function
|
||
|
# with no output.
|
||
|
if outputs is None:
|
||
|
outputs = []
|
||
|
else:
|
||
|
# If func only returned one value, make it a tuple.
|
||
|
if not isinstance(outputs, (list, tuple)):
|
||
|
outputs = (outputs,)
|
||
|
if any([_ is None for _ in outputs]):
|
||
|
raise ValueError("Function can not return None.")
|
||
|
# Ensures each output is a Tensor in the function graph.
|
||
|
outputs = [ops.convert_to_tensor(t) for t in outputs]
|
||
|
outputs = [func_graph.capture(t) if t.graph is not func_graph else t
|
||
|
for t in outputs]
|
||
|
func_graph.outputs = outputs
|
||
|
return func_graph
|
||
|
|
||
|
|
||
|
def _is_guaranteed_const(tensor):
|
||
|
"""Determines whether `tensor` is guaranteed to be a constant.
|
||
|
|
||
|
A tensor is guaranteed to be a constant if either it was produced by
|
||
|
a `GuaranteeConst` op or if all of its children are guaranteed to be
|
||
|
constants.
|
||
|
|
||
|
Args:
|
||
|
tensor: The tensor for which to determine const-ness.
|
||
|
|
||
|
Returns:
|
||
|
True if `tensor` is guaranteed to be a constant, False otherwise.
|
||
|
"""
|
||
|
|
||
|
if isinstance(tensor, ops.EagerTensor):
|
||
|
return False
|
||
|
|
||
|
class Work(object):
|
||
|
|
||
|
def __init__(self, op, leaving):
|
||
|
self.op = op
|
||
|
self.leaving = leaving
|
||
|
|
||
|
is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
|
||
|
constants = set([])
|
||
|
def all_inputs_const(op):
|
||
|
# If all inputs of an op are guaranteed constants, then we can infer that
|
||
|
# the op produces a constant as well.
|
||
|
return op.inputs and all(inp.op in constants for inp in op.inputs)
|
||
|
|
||
|
visited = set([])
|
||
|
stack = [Work(tensor.op, leaving=False)]
|
||
|
while stack:
|
||
|
work = stack.pop()
|
||
|
if work.leaving:
|
||
|
if all_inputs_const(work.op):
|
||
|
constants.add(work.op)
|
||
|
continue
|
||
|
visited.add(work.op)
|
||
|
if is_guaranteed_const(work.op):
|
||
|
constants.add(work.op)
|
||
|
continue
|
||
|
|
||
|
# This op will be revisited after all its inputs are checked for const-ness.
|
||
|
stack.append(Work(work.op, leaving=True))
|
||
|
for inp in work.op.inputs:
|
||
|
if inp.op not in visited:
|
||
|
stack.append(Work(inp.op, leaving=False))
|
||
|
return tensor.op in constants
|
||
|
|
||
|
|
||
|
def _call(sig, *inputs, **kwargs):
|
||
|
"""Adds a node calling a function.
|
||
|
|
||
|
This adds a `call` op to the default graph that calls the function
|
||
|
of signature `sig`, passing the tensors in `inputs` as arguments.
|
||
|
It returns the outputs of the call, which are one or more tensors.
|
||
|
|
||
|
`sig` is OpDefArg.a `_DefinedFunction` object.
|
||
|
|
||
|
You can pass an optional keyword parameter `name=string` to name the
|
||
|
added operation.
|
||
|
|
||
|
You can pass an optional keyword parameter `noinline=True|False` to
|
||
|
instruct the runtime not to inline the function body into the call
|
||
|
site.
|
||
|
|
||
|
Args:
|
||
|
sig: OpDefArg. The signature of the function.
|
||
|
*inputs: arguments to the function.
|
||
|
**kwargs: Optional keyword arguments. Can only contain 'name' or
|
||
|
'noinline'.
|
||
|
|
||
|
Returns:
|
||
|
A 2-element tuple. First element: a Tensor if the function returns a single
|
||
|
value; a list of Tensors if the function returns multiple value; the
|
||
|
Operation if the function returns no values. Second element: the Operation.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: if the arguments are invalid.
|
||
|
"""
|
||
|
if len(inputs) != len(sig.input_arg):
|
||
|
raise ValueError("Expected number of arguments: %d, received: %d" % (len(
|
||
|
sig.input_arg), len(inputs)))
|
||
|
name = kwargs.pop("name", None)
|
||
|
g = ops.get_default_graph()
|
||
|
func_name = sig.name
|
||
|
attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
|
||
|
output_types = [dtypes.DType(x.type) for x in sig.output_arg]
|
||
|
with ops.name_scope(name, func_name, inputs) as name:
|
||
|
op = g.create_op(
|
||
|
func_name,
|
||
|
list(inputs),
|
||
|
output_types,
|
||
|
name=name,
|
||
|
attrs=attrs,
|
||
|
op_def=sig,
|
||
|
compute_shapes=False)
|
||
|
if op.outputs:
|
||
|
if len(op.outputs) == 1:
|
||
|
ret = op.outputs[0]
|
||
|
else:
|
||
|
ret = tuple(op.outputs)
|
||
|
else:
|
||
|
ret = op
|
||
|
return ret, op
|
||
|
|
||
|
|
||
|
def _from_definition(fdef, grad_func=None):
|
||
|
"""Creates a _DefinedFunction initialized from a FunctionDef proto.
|
||
|
|
||
|
Args:
|
||
|
fdef: a FunctionDef
|
||
|
grad_func: a _DefinedFunction or None
|
||
|
|
||
|
Returns:
|
||
|
A _DefinedFunction representing fdef
|
||
|
"""
|
||
|
# TODO(iga): This method does major surgery on _DefinedFunction.
|
||
|
# Make it a named constructor using @classmethod of _DefinedFunction.
|
||
|
|
||
|
# The Python callable is only needed to create a FunctionDef. Since we have
|
||
|
# the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we
|
||
|
# have access to such a callable here).
|
||
|
func = None
|
||
|
argnames = [arg.name for arg in fdef.signature.input_arg]
|
||
|
input_types = tuple(
|
||
|
dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
|
||
|
func_name = fdef.signature.name
|
||
|
# Note: FunctionDefs do not include python gradient functions, so if the
|
||
|
# original _DefinedFunction included one it will not be reflected here.
|
||
|
python_grad_func = None
|
||
|
out_names = [arg.name for arg in fdef.signature.output_arg]
|
||
|
result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
|
||
|
python_grad_func, out_names)
|
||
|
# pylint: disable=protected-access
|
||
|
if ops._USE_C_API:
|
||
|
serialized = fdef.SerializeToString()
|
||
|
c_func = c_api.TF_FunctionImportFunctionDef(serialized)
|
||
|
result._c_func = c_api_util.ScopedTFFunction(c_func)
|
||
|
result._extra_inputs = []
|
||
|
else:
|
||
|
result._definition = fdef
|
||
|
# Captured inputs are added as regular inputs to a function when it's
|
||
|
# serialized, i.e. any extra inputs from the original function are now
|
||
|
# included in `result`._args
|
||
|
result._extra_inputs = []
|
||
|
result._hash_str = result._create_hash_str(
|
||
|
result._definition.signature.input_arg,
|
||
|
result._definition.signature.output_arg, result._definition.node_def)
|
||
|
# pylint: enable=protected-access
|
||
|
|
||
|
return result
|
||
|
|
||
|
|
||
|
def _from_library(lib):
|
||
|
"""Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
|
||
|
|
||
|
This method handles assigning the correct gradient functions to each
|
||
|
function.
|
||
|
|
||
|
Args:
|
||
|
lib: a FunctionDefLibrary
|
||
|
|
||
|
Returns:
|
||
|
A list of _DefinedFunctions
|
||
|
|
||
|
Raises:
|
||
|
ValueError: `lib` is invalid
|
||
|
"""
|
||
|
if not lib.function and not lib.gradient:
|
||
|
return []
|
||
|
|
||
|
# function name -> FunctionDef proto
|
||
|
funcs = {fdef.signature.name: fdef for fdef in lib.function}
|
||
|
|
||
|
# Validate that all references function names have function defs
|
||
|
for g in lib.gradient:
|
||
|
if g.function_name not in funcs:
|
||
|
raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
|
||
|
(g.function_name, str(lib)))
|
||
|
if g.gradient_func not in funcs:
|
||
|
raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
|
||
|
(g.gradient_func, str(lib)))
|
||
|
|
||
|
# function name -> gradient function name
|
||
|
func_to_grad = collections.defaultdict(lambda: None)
|
||
|
# gradient function name -> names of functions having that grad function
|
||
|
grad_to_funcs = collections.defaultdict(list)
|
||
|
|
||
|
for gdef in lib.gradient:
|
||
|
func_to_grad[gdef.function_name] = gdef.gradient_func
|
||
|
grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
|
||
|
|
||
|
# Start with functions without gradients
|
||
|
ready = [
|
||
|
fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
|
||
|
]
|
||
|
if not ready:
|
||
|
raise ValueError(
|
||
|
"FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))
|
||
|
# function name -> _DefinedFunction
|
||
|
initialized = {}
|
||
|
|
||
|
while ready:
|
||
|
fdef = ready.pop()
|
||
|
name = fdef.signature.name
|
||
|
|
||
|
grad = initialized.get(func_to_grad[name])
|
||
|
if func_to_grad[name]:
|
||
|
assert grad
|
||
|
defined_func = _from_definition(fdef, grad_func=grad)
|
||
|
initialized[name] = defined_func
|
||
|
|
||
|
ready.extend(funcs[f] for f in grad_to_funcs[name])
|
||
|
|
||
|
return initialized.values()
|
||
|
|
||
|
|
||
|
def _parse_kwargs_as_attrs(func_name, **kwargs):
|
||
|
"""Parses **kwargs into a node's attributes."""
|
||
|
attrs = {}
|
||
|
|
||
|
noinline = kwargs.pop("noinline", None)
|
||
|
if noinline is not None:
|
||
|
attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))
|
||
|
|
||
|
compiled = kwargs.pop("compiled", None)
|
||
|
separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
|
||
|
if compiled is not None:
|
||
|
attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
|
||
|
attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
|
||
|
b=bool(separate_compiled_gradients))
|
||
|
# Forward _XlaScope from enclosing context (if set), otherwise create new.
|
||
|
# pylint: disable=protected-access
|
||
|
if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
|
||
|
attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
|
||
|
else:
|
||
|
attrs["_XlaScope"] = attr_value_pb2.AttrValue(
|
||
|
s=("function_%s" % func_name).encode())
|
||
|
# pylint: enable=protected-access
|
||
|
|
||
|
kwargs_keys = list(kwargs.keys())
|
||
|
for key in kwargs_keys:
|
||
|
if key.startswith("experimental_"):
|
||
|
attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
|
||
|
del kwargs[key]
|
||
|
|
||
|
if kwargs:
|
||
|
raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
|
||
|
return attrs
|
||
|
|
||
|
|
||
|
def _get_func_name(func):
|
||
|
_, func = tf_decorator.unwrap(func)
|
||
|
if callable(func):
|
||
|
if tf_inspect.isfunction(func):
|
||
|
return func.__name__
|
||
|
elif tf_inspect.ismethod(func):
|
||
|
return "%s.%s" % (func.__self__.__name__, func.__name__)
|
||
|
else: # Probably a class instance with __call__
|
||
|
return type(func)
|
||
|
else:
|
||
|
raise ValueError("Argument must be callable")
|
||
|
|
||
|
|
||
|
def get_extra_vars():
|
||
|
"""Returns the captured variables by the function.
|
||
|
|
||
|
Returns:
|
||
|
If the default graph is being used to define a function, the
|
||
|
returned list of variables are those created inside the function
|
||
|
body so far. Otherwise, returns an empty list.
|
||
|
"""
|
||
|
g = ops.get_default_graph()
|
||
|
if isinstance(g, _FuncGraph):
|
||
|
return g.extra_vars
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
|
||
|
def get_extra_inputs():
|
||
|
"""Returns the captured input tensors by the function.
|
||
|
|
||
|
Returns:
|
||
|
If the default graph is being used to define a function, the
|
||
|
returned list of tensors are those accessed inside the function body
|
||
|
but defined outside the function body so far. Otherwise, returns an
|
||
|
empty list.
|
||
|
"""
|
||
|
g = ops.get_default_graph()
|
||
|
if isinstance(g, _FuncGraph):
|
||
|
return g.extra_inputs
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
|
||
|
def get_extra_args():
|
||
|
"""Returns the corresponding function arguments for the captured inputs.
|
||
|
|
||
|
Returns:
|
||
|
If the default graph is being used to define a function, the
|
||
|
returned list of place holders are those used inside the function
|
||
|
body corresponding those returned by get_extra_inputs(). Otherwise,
|
||
|
returns an empty list.
|
||
|
"""
|
||
|
g = ops.get_default_graph()
|
||
|
if isinstance(g, _FuncGraph):
|
||
|
return g.extra_args
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
|
||
|
def _type_list_to_str(types):
|
||
|
if any([_ not in _DTYPE_TO_STR for _ in types]):
|
||
|
raise ValueError("Unsupported dtypes: %s" % types)
|
||
|
return "".join([_DTYPE_TO_STR[_] for _ in types])
|
||
|
|
||
|
|
||
|
# NOTE: The list needs to be extended when more data types are added.
|
||
|
_DTYPE_TO_STR = {
|
||
|
dtypes.float16: "f16",
|
||
|
dtypes.float32: "f32",
|
||
|
dtypes.float64: "f64",
|
||
|
dtypes.int32: "i32",
|
||
|
dtypes.uint8: "i8",
|
||
|
dtypes.uint16: "u16",
|
||
|
dtypes.uint32: "u32",
|
||
|
dtypes.uint64: "u64",
|
||
|
dtypes.int16: "i16",
|
||
|
dtypes.int8: "i8",
|
||
|
dtypes.string: "s",
|
||
|
dtypes.complex64: "c64",
|
||
|
dtypes.complex128: "c128",
|
||
|
dtypes.int64: "i64",
|
||
|
dtypes.bool: "b",
|
||
|
dtypes.qint8: "qi8",
|
||
|
dtypes.quint8: "qu8",
|
||
|
dtypes.qint16: "qi16",
|
||
|
dtypes.quint16: "qu16",
|
||
|
dtypes.qint32: "qi32",
|
||
|
dtypes.bfloat16: "b16"
|
||
|
}
|
||
|
|
||
|
|
||
|
def function_def_from_tf_function(c_func):
|
||
|
"""Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
|
||
|
with c_api_util.tf_buffer() as buf:
|
||
|
c_api.TF_FunctionToFunctionDef(c_func, buf)
|
||
|
data = c_api.TF_GetBuffer(buf)
|
||
|
fdef = function_pb2.FunctionDef()
|
||
|
fdef.ParseFromString(compat.as_bytes(data))
|
||
|
return fdef
|