158 lines
6 KiB
Python
158 lines
6 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Converting AST to code.
|
|
|
|
Adapted from Tangent.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
# TODO(mdan): Use six for compatibility here.
|
|
import atexit
|
|
import imp
|
|
import os
|
|
import tempfile
|
|
|
|
import astor
|
|
import gast
|
|
|
|
from tensorflow.contrib.autograph.pyct import anno
|
|
from tensorflow.contrib.autograph.pyct import ast_util
|
|
from tensorflow.contrib.autograph.pyct import origin_info
|
|
from tensorflow.contrib.autograph.pyct import parser
|
|
|
|
|
|
def _build_source_map(node, code):
  """Maps line numbers of generated code back to user-code origin annotations.

  After the final code has been generated, it is re-parsed so that the line
  numbers recorded on the new AST correspond to the text of `code`. The
  original and re-parsed ASTs are then walked in parallel, pairing each
  original node with its counterpart in the generated code.

  Args:
    node: An AST node of the original generated code, before the source code
      is generated. Only nodes carrying an `anno.Basic.ORIGIN` annotation
      contribute to the mapping.
    code: The string representation of the source code for the newly
      generated code.

  Returns:
    A dict whose keys are the `line_number` values of the `anno.Basic.ORIGIN`
    annotations found on the re-parsed nodes (after `origin_info.resolve`).
    Its values are the `anno.Basic.ORIGIN` annotations of the matching nodes
    in `node`.

    The keys are plain line numbers relative to the start of `code`, not file
    locations. Callers that need `origin_info.CodeLocation` keys must build
    them. When several nodes map to the same generated line, the pair visited
    last by `ast_util.parallel_walk` wins.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)

  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks because if origin information is ever copied over to new
    # nodes then we need to rely on the fact that only the original user code
    # has the origin annotation.
    if (anno.hasanno(before, anno.Basic.ORIGIN) and
        anno.hasanno(after, anno.Basic.ORIGIN)):
      source_info = anno.getanno(before, anno.Basic.ORIGIN)
      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
      source_mapping[new_line_number] = source_info

  return source_mapping
|
|
|
|
|
|
def ast_to_source(node, indentation=' '):
  """Return the source code of given AST, together with its source map.

  Args:
    node: The AST to convert. A `gast` node is first converted to a native
      Python `ast` node with `gast.gast_to_ast`. Any other node is handed to
      astor's code generator as-is.
    indentation: The string to use for one level of indentation in the
      generated code.

  Returns:
    A tuple `(code, source_mapping)`:
      - `code` is the generated source as a string, with leading whitespace
        stripped.
      - `source_mapping` is the dict built by `_build_source_map`, keyed by
        line numbers within `code`.
  """
  # Keep the node exactly as passed in (before any gast -> ast conversion);
  # `_build_source_map` reads the origin annotations from this tree.
  original_node = node
  if isinstance(node, gast.AST):
    node = gast.gast_to_ast(node)

  generator = astor.codegen.SourceGenerator(indentation, False,
                                            astor.string_repr.pretty_string)
  generator.visit(node)
  generator.result.append('\n')

  # In some versions of Python, literals may appear as actual values. This
  # ensures everything is string.
  code = map(str, generator.result)
  code = astor.source_repr.pretty_source(code).lstrip()

  source_mapping = _build_source_map(original_node, code)

  return code, source_mapping
|
|
|
|
|
|
def _offset_source_map(code_source_mapping, file_path, source_prefix):
  """Rekeys a line-number source map by location in the written module file.

  `ast_to_object` writes `source_prefix` (if any) followed by a newline before
  the generated code. Every generated line number must therefore be shifted by
  the number of lines that the prefix occupies.

  Args:
    code_source_mapping: Dict mapping line numbers within the generated code
      (as returned by `ast_to_source`) to origin information.
    file_path: Path of the file the generated code is written to.
    source_prefix: The optional string written before the generated code, or
      None.

  Returns:
    Dict mapping `origin_info.CodeLocation(file_path, line_number)` to the
    same origin information. Line numbers are counted from the start of
    `file_path`.
  """
  if source_prefix:
    # The prefix is followed by one extra '\n', so it occupies one more line
    # than the number of newlines it contains.
    num_import_lines = source_prefix.count('\n') + 1
  else:
    num_import_lines = 0
  return {
      origin_info.CodeLocation(
          file_path=file_path, line_number=line_number + num_import_lines):
          original_position
      for line_number, original_position in code_source_mapping.items()
  }


def _remove_file_if_exists(path):
  """Deletes `path`, tolerating the file having already been removed."""
  try:
    os.remove(path)
  except OSError:
    # Best-effort cleanup at interpreter exit. Only a missing file is
    # tolerated; any other failure (e.g. permissions) is still reported.
    if os.path.exists(path):
      raise


def ast_to_object(node,
                  indentation=' ',
                  source_prefix=None,
                  delete_on_exit=True):
  """Return the Python objects represented by given AST.

  Compiling the AST code this way ensures that the source code is readable by
  e.g. `pdb` or `inspect`.

  The generated source is written to a named temporary `.py` file, which is
  then loaded as a module. The file is created with `delete=False`, so it
  outlives this call.

  Args:
    node: The code to compile, as an AST object.
    indentation: The string to use for indentation.
    source_prefix: Optional string to print as-is into the source file.
    delete_on_exit: Whether to delete the temporary file used for compilation
      on exit.

  Returns:
    A tuple `(compiled_node, source)`:
      - `compiled_node` is the module object loaded from the temporary file.
        Its namespace contains an `ag_source_map__` entry mapping
        `origin_info.CodeLocation`s in that file to origin information.
      - `source` is the generated source code, excluding `source_prefix`.

  Raises:
    ValueError: If ag_source_map__ is already in the namespace of the compiled
      node.
  """
  # code_source_mapping is keyed by line numbers within `source` alone; it
  # does not yet account for the lines taken by `source_prefix`.
  source, code_source_mapping = ast_to_source(node, indentation=indentation)

  with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
    source_mapping = _offset_source_map(code_source_mapping, f.name,
                                        source_prefix)
    module_name = os.path.basename(f.name[:-3])
    if source_prefix:
      f.write(source_prefix)
      f.write('\n')
    f.write(source)

  if delete_on_exit:
    atexit.register(_remove_file_if_exists, f.name)

  # The file has been closed (and therefore flushed) by the `with` block above.
  # NOTE: `imp` is deprecated and was removed in Python 3.12; the
  # `importlib.util` equivalent is spec_from_file_location + exec_module.
  compiled_node = imp.load_source(module_name, f.name)

  # TODO(znado): Clean this up so we don't need to attach it to the namespace.
  # TODO(znado): This does not work for classes because their methods share a
  # namespace.
  # This attaches the source map which is needed for error handling. Note that
  # api.to_graph copies this source map into an attribute of the function.
  #
  # We need this so the ag_source_map__ variable is available to the call to
  # rewrite_graph_construction_error in the except block inside each function
  # that handles graph construction errors.
  #
  # We cannot get the rewritten function name until it is too late so
  # templating is hard, and this cleanly fixes the issues encountered with
  # nested functions because this is attached to the outermost one.
  source_map_name = 'ag_source_map__'
  if source_map_name in compiled_node.__dict__:
    raise ValueError('cannot convert %s because it has namespace attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_node, source_map_name))
  compiled_node.__dict__[source_map_name] = source_mapping

  return compiled_node, source
|