# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core conversion logic, serves as main point of access."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp

import gast

from tensorflow.contrib.autograph import operators
from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.converters import asserts
from tensorflow.contrib.autograph.converters import break_statements
from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.converters import call_trees
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import decorators
from tensorflow.contrib.autograph.converters import error_handlers
from tensorflow.contrib.autograph.converters import ifexp
from tensorflow.contrib.autograph.converters import lists
from tensorflow.contrib.autograph.converters import logical_expressions
from tensorflow.contrib.autograph.converters import name_scopes
from tensorflow.contrib.autograph.converters import side_effect_guards
from tensorflow.contrib.autograph.converters import single_return
from tensorflow.contrib.autograph.converters import slices
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.core import errors
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import inspect_utils
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import activity
from tensorflow.contrib.autograph.pyct.static_analysis import live_values
from tensorflow.contrib.autograph.pyct.static_analysis import type_info
from tensorflow.python.util import tf_inspect


# TODO(mdan): Might we not need any renaming at all?


def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.

  Returns:
    Boolean. True if the name of the module that defines `o` starts with one
    of the prefixes in `config.DEFAULT_UNCOMPILED_MODULES`; False otherwise,
    including when no defining module can be determined for `o`.
  """
  m = tf_inspect.getmodule(o)
  if m is None:
    # Without a module there is no name to match against the whitelist.
    return False
  # Each whitelist entry is unpacked as a one-element sequence.
  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    if m.__name__.startswith(prefix):
      return True
  return False


def _next_unconverted_dependency(program_ctx):
  """Returns the next entity in the name map that still needs conversion.

  Entities already present in `program_ctx.dependency_cache` are skipped. So
  are entities that carry an `im_class` attribute whose value is not listed in
  `program_ctx.partial_types`: class members are converted with their objects,
  unless they're only converted partially.

  Args:
    program_ctx: A ProgramContext object.

  Returns:
    The first eligible entity, or None if there is none left.
  """
  for obj in program_ctx.name_map.keys():
    if obj in program_ctx.dependency_cache:
      continue
    if (hasattr(obj, 'im_class') and
        getattr(obj, 'im_class') not in program_ctx.partial_types):
      continue
    return obj
  return None


def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed it creates TF a graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted
            entity, keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
    NotImplementedError: if the entity is a lambda function.
  """
  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o):
    # TODO(mdan): This is not a reliable mechanism.
    # The most reliable way is to check the source code, the AST will contain
    # a Lambda node instead of a FunctionDef
    if o.__name__ == '<lambda>':
      raise NotImplementedError(
          'lambda functions are not yet supported; declare the function'
          ' using def instead: %s' % o)
    else:
      node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  elif tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  program_ctx.add_to_cache(o, node)

  if program_ctx.recursive:
    # Keep converting until every eligible entity in the name map is cached.
    # Ineligible entities are filtered out during selection; selecting one and
    # then skipping it would pick the same entity again on every iteration.
    while True:
      candidate = _next_unconverted_dependency(program_ctx)
      if candidate is None:
        break
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns


def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Args:
    c: The class to convert.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (node, class_name, class_namespace):
        * node: A gast.Module holding any generated base class imports,
            followed by the converted class definition.
        * class_name: The name given to the converted class.
        * class_namespace: A dict merging the namespaces of all converted
            member functions.

  Raises:
    ValueError: if the class has no member functions or methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    class_namespace.update(namespace)
    converted_members[m] = node
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  bases = []
  for base in c.__bases__:
    # This holds when `base` is `object` itself.
    # NOTE(review): it also holds for `base is type`; confirm whether
    # `base is object` was intended.
    if isinstance(object, base):
      bases.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    bases.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  output_nodes.append(
      gast.ClassDef(
          class_name,
          bases=bases,
          keywords=[],
          body=list(converted_members.values()),
          decorator_list=[]))
  node = gast.Module(output_nodes)

  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  node = qual_names.resolve(node)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  node = ast_util.rename_symbols(node, renames)

  return node, class_name, class_namespace


def _add_reserved_symbol(namespace, name, entity):
  """Binds `name` to `entity` in `namespace`, rejecting conflicting bindings.

  Args:
    namespace: A dict mapping symbol names to entities; modified in place.
    name: The reserved symbol name.
    entity: The entity that the name must refer to.

  Raises:
    ValueError: if `name` is already bound to a value that compares unequal to
        `entity`.
  """
  if name not in namespace:
    namespace[name] = entity
  elif namespace[name] != entity:
    raise ValueError('The name "%s" is reserved and may not be used.' % name)


# Module object exposed to converted code under the reserved name `ag__`.
# Created on the first call to _add_self_references and reused afterwards.
ag_internal = None


def _add_self_references(namespace, autograph_module):
  """Adds namespace references to the module that exposes the api itself.

  Args:
    namespace: A dict mapping symbol names to entities; modified in place.
    autograph_module: The module providing `converted_call`. It is only read
        on the first call, when the shared internal module is created.

  Raises:
    ValueError: if `ag__` is already bound to something else in `namespace`.
  """
  global ag_internal
  if ag_internal is None:
    # Craft a module that exposes parts of the external API as well as certain
    # internal modules.
    ag_internal = imp.new_module('autograph')
    ag_internal.converted_call = autograph_module.converted_call
    ag_internal.utils = utils
    ag_internal.rewrite_graph_construction_error = (
        errors.rewrite_graph_construction_error)
    # TODO(mdan): Add safeguards against name clashes.
    # We don't want to create a submodule because we want the operators to be
    # accessible directly as attributes of ag__.
    ag_internal.__dict__.update(operators.__dict__)

  _add_reserved_symbol(namespace, 'ag__', ag_internal)


def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function or method to convert.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    owner_type: Optional owner class, passed through to the entity info and to
        the namer.

  Returns:
    A tuple (node, new_name, namespace):
        * node: The converted AST node, with its name set to `new_name`.
        * new_name: The name of the converted function; equal to `f.__name__`
            when the namer did not rename it.
        * namespace: The namespace of `f`, extended with the reserved `ag__`
            symbol.

  Raises:
    NotImplementedError: if the function was not renamed and the name of the
        parsed node differs from `f.__name__`.
  """
  node, source = parser.parse_entity(f)
  # The parsed entity is the first statement of the returned node's body.
  node = node.body[0]
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  node = node_to_graph(node, context)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    new_name = f.__name__
    if node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')

  node.name = new_name
  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return node, new_name, namespace


def _apply_transformer(node, context, converter_module):
  """Re-runs static analysis on `node`, then applies one converter.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext
    converter_module: A module exposing `transform(node, context)`.

  Returns:
    The node returned by `converter_module.transform`.
  """
  # TODO(mdan): Clear static analysis here.
  node = qual_names.resolve(node)
  node = activity.resolve(node, context.info, None)
  node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
  node = type_info.resolve(node, context.info)
  node = converter_module.transform(node, context)
  return node


def node_to_graph(node, context):
  """Convert Python code to equivalent TF graph mode code.

  The converters are applied one after another, in a fixed order.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    A Python ast node, representing the converted code.
  """
  # TODO(mdan): Verify arguments for correctness.

  node = _apply_transformer(node, context, ifexp)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None

  node = _apply_transformer(node, context, decorators)
  node = _apply_transformer(node, context, break_statements)
  node = _apply_transformer(node, context, asserts)

  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = _apply_transformer(node, context, continue_statements)
  # Make the builtin `len` visible to the remaining passes.
  context.info.namespace['len'] = len
  node = _apply_transformer(node, context, single_return)
  node = _apply_transformer(node, context, lists)
  node = _apply_transformer(node, context, slices)
  node = _apply_transformer(node, context, builtin_functions)
  node = _apply_transformer(node, context, call_trees)
  node = _apply_transformer(node, context, control_flow)
  node = _apply_transformer(node, context, logical_expressions)
  node = _apply_transformer(node, context, side_effect_guards)
  node = _apply_transformer(node, context, name_scopes)
  node = _apply_transformer(node, context, error_handlers)
  return node