laywerrobot/lib/python3.6/site-packages/tensorflow/contrib/autograph/pyct/compiler.py

159 lines
6 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting AST to code.
Adapted from Tangent.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# TODO(mdan): Use six for compatibility here.
import atexit
import imp
import os
import tempfile
import astor
import gast
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
def _build_source_map(node, code):
"""Return the Python objects represented by given AST.
Compiling the AST code this way ensures that the source code is readable by
e.g. `pdb` or `inspect`.
Args:
node: An AST node of the original generated code, before the source code is
generated.
code: The string representation of the source code for the newly generated
code.
Returns:
Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph
generated code.
"""
# After we have the final generated code we reparse it to get the final line
# numbers. Then we walk through the generated and original ASTs in parallel
# to build the mapping between the user and generated code.
new_node = parser.parse_str(code)
origin_info.resolve(new_node, code)
source_mapping = {}
for before, after in ast_util.parallel_walk(node, new_node):
# Need both checks because if origin information is ever copied over to new
# nodes then we need to rely on the fact that only the original user code
# has the origin annotation.
if (anno.hasanno(before, anno.Basic.ORIGIN) and
anno.hasanno(after, anno.Basic.ORIGIN)):
source_info = anno.getanno(before, anno.Basic.ORIGIN)
new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
source_mapping[new_line_number] = source_info
return source_mapping
def ast_to_source(node, indentation=' '):
"""Return the source code of given AST."""
original_node = node
if isinstance(node, gast.AST):
node = gast.gast_to_ast(node)
generator = astor.codegen.SourceGenerator(indentation, False,
astor.string_repr.pretty_string)
generator.visit(node)
generator.result.append('\n')
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
code = map(str, generator.result)
code = astor.source_repr.pretty_source(code).lstrip()
source_mapping = _build_source_map(original_node, code)
return code, source_mapping
def ast_to_object(node,
indentation=' ',
source_prefix=None,
delete_on_exit=True):
"""Return the Python objects represented by given AST.
Compiling the AST code this way ensures that the source code is readable by
e.g. `pdb` or `inspect`.
Args:
node: The code to compile, as an AST object.
indentation: The string to use for indentation.
source_prefix: Optional string to print as-is into the source file.
delete_on_exit: Whether to delete the temporary file used for compilation on
exit.
Returns:
A module object containing the compiled source code.
Raises:
ValueError: If ag_source_map__ is already in the namespace of the compiled
node.
"""
# code_source_mapping does not yet include the offsets from import statements.
source, code_source_mapping = ast_to_source(node, indentation=indentation)
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
# TODO(znado): move into an _offset_source_map() helper function.
# Need to offset the generated line numbers by the number of import lines.
if source_prefix:
num_import_lines = source_prefix.count('\n') + 1
else:
num_import_lines = 0
source_mapping = {}
for line_number, original_position in code_source_mapping.items():
source_map_key = origin_info.CodeLocation(
file_path=f.name, line_number=line_number + num_import_lines)
source_mapping[source_map_key] = original_position
module_name = os.path.basename(f.name[:-3])
if source_prefix:
f.write(source_prefix)
f.write('\n')
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
compiled_node = imp.load_source(module_name, f.name)
# TODO(znado): Clean this up so we don't need to attach it to the namespace.
# TODO(znado): This does not work for classes because their methods share a
# namespace.
# This attaches the source map which is needed for error handling. Note that
# api.to_graph copies this source map into an attribute of the function.
#
# We need this so the ag_source_map__ variable is available to the call to
# rewrite_graph_construction_error in the except block inside each function
# that handles graph construction errors.
#
# We cannot get the rewritten function name until it is too late so templating
# is hard, and this cleanly fixes the
# issues encountered with nested functions because this is attached to the
# outermost one.
source_map_name = 'ag_source_map__'
if source_map_name in compiled_node.__dict__:
raise ValueError('cannot convert %s because is has namespace attribute '
'"%s", which is reserved for AutoGraph.' %
(compiled_node, source_map_name))
compiled_node.__dict__[source_map_name] = source_mapping
return compiled_node, source