# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import wraps

from enum import Enum

# pylint:disable=g-bad-import-order
import gast
import six
# pylint:enable=g-bad-import-order

from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import conversion
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import inspect_utils
from tensorflow.contrib.autograph.utils import builtins
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

# TODO(mdan): Properly document the type hints.
# TODO(mdan): Reduce the type hint information to (module, type).
# (currently we require (module + class name, type))


def convert(recursive=False, verbose=False, arg_types=None):
  """Decorator that compiles a function to graph mode.

  The decorator is dynamic - invoking compilation whenever the decorated
  function is called. This means the parameter values are known at compilation.

  Nothing is compiled when the decorator is applied; every call of the
  decorated function is routed through `converted_call`, so conversion errors
  surface at call time.

  Args:
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_types: See to_graph. Defaults to an empty dict.

  Returns:
    A decorator that compiles the given function to graph mode.
  """
  if arg_types is None:
    arg_types = {}

  def decorator(f):
    """Decorator implementation."""

    @wraps(f)
    def wrapper(*args, **kwargs):
      return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)

    wrapper = tf_decorator.make_decorator(f, wrapper)

    # Sometimes the decorator is just desugared, making it impossible to detect.
    # This attribute makes detection easier.
    setattr(wrapper, '__pyct_is_compile_decorator', True)
    return wrapper

  return decorator


class RunMode(Enum):
  """Selects how `do_not_convert` runs the function it wraps.

  GRAPH calls the function as-is; PY_FUNC routes the call through
  `py_func.wrap_py_func`.
  """
  GRAPH = 1
  PY_FUNC = 2


def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
  """Decorator that suppresses compilation of a function.

  Args:
    run_as: RunMode value. Whether to run the function as-is, or wrap it into
        a py_func.
    return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or
        empty list or tuple will create a dummy return value that can be used
        to set control dependencies.

  Returns:
    A decorator that wraps the original function.

  Raises:
    ValueError: When the returned decorator is applied and `run_as` is not a
        known RunMode value.
  """

  def decorator(f):
    """Decorator implementation."""

    @wraps(f)
    def graph_wrapper(*args, **kwargs):
      return f(*args, **kwargs)

    @wraps(f)
    def py_func_wrapper(*args, **kwargs):
      if kwargs:
        raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
      # TODO(mdan): Add support for kwargs.
      return py_func.wrap_py_func(
          f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)

    if run_as == RunMode.GRAPH:
      wrapper = graph_wrapper
    elif run_as == RunMode.PY_FUNC:
      wrapper = py_func_wrapper
    else:
      raise ValueError('unknown value for run_as: %s' % run_as)

    # Sometimes the decorator is just desugared, making it impossible to detect.
    # This attribute makes detection easier.
    setattr(wrapper, '__pyct_is_compile_decorator', True)
    return wrapper

  return decorator


def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
  """Compiles a function call inline.

  Whitelisted callables are invoked directly and builtins are dispatched to
  `builtins.dynamic_builtin`. Anything else is converted with `to_graph` and
  the converted entity is then called with the supplied arguments.

  Args:
    f: The callable to invoke: a function, method, class (constructor) or an
        object defining `__call__`.
    recursive: See to_graph.
    verbose: See to_graph.
    arg_types: Dict of type hints keyed by parameter name, or None. Hints for
        parameters not listed are inferred from the runtime argument values.
        The dict passed in is never modified.
    *args: Positional arguments for the call.
    **kwargs: Keyword arguments for the call.

  Returns:
    Whatever the (possibly converted) call returns.

  Raises:
    NotImplementedError: If `f` is none of the supported kinds of callable.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.

  if conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  unknown_arg_value = object()  # Sentinel for arguments of unknown value

  if inspect_utils.isbuiltin(f):
    return builtins.dynamic_builtin(f, *args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Without the `raise`, execution would fall through and fail further down
    # with an unrelated UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  # Work on a private copy. The `convert` decorator passes the same dict on
  # every call, so writing inferred types into it would pin the types seen on
  # the first call for all later calls.
  arg_types = dict(arg_types) if arg_types else {}

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  for name, arg in arg_values.items():
    # Nothing in this function stores the sentinel into arg_values; the guard
    # is retained from the original code.
    if arg is unknown_arg_value:
      continue
    arg_class = arg.__class__
    # If arg_types specifies any name, use that instead.
    if name not in arg_types:
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)


def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed
    creates a TF graph that has the same functionality as the original entity.

  Raises:
    ValueError: If the converted function defines or refers to symbol names that
        are reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  module = gast.Module([])
  # Materialize the values before reversing: dict views are not reversible on
  # every supported Python version. This mirrors what to_code does.
  module.body.extend(
      reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
  compiled_node, compiled_src = compiler.ast_to_object(
      module, source_prefix=program_ctx.required_imports)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_node.__dict__:
      compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled_fn, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_fn, source_map_attribute_name))
  setattr(compiled_fn, source_map_attribute_name,
          compiled_node.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn


def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Return the equivalent of an entity in TensorFlow code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the required imports, a blank line, then the generated source of
    every converted dependency.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Emit the cached dependencies in reverse order, one source chunk per entry.
  code = '\n'.join(
      compiler.ast_to_source(dep, indentation)[0]
      for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))

  return program_ctx.required_imports + '\n\n' + code