# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """AST node annotation support. Adapted from Tangent. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import enum # pylint:disable=g-bad-import-order import gast # pylint:enable=g-bad-import-order # TODO(mdan): Shorten the names. # These names are heavily used, and anno.blaa # TODO(mdan): Replace the attr-dict mechanism with a more typed solution. class NoValue(enum.Enum): def __repr__(self): return self.name class Basic(NoValue): """Container for basic annotation keys. The enum values are used strictly for documentation purposes. """ QN = 'Qualified name, as it appeared in the code. See qual_names.py.' SKIP_PROCESSING = ( 'This node should be preserved as is and not processed any further.') INDENT_BLOCK_REMAINDER = ( 'When a node is annotated with this, the remainder of the block should' ' be indented below it. The annotation contains a tuple' ' (new_body, name_map), where `new_body` is the new indented block and' ' `name_map` allows renaming symbols.') ORIGIN = ('Information about the source code that converted code originated' ' from. See origin_information.py.') class Static(NoValue): """Container for static analysis annotation keys. The enum values are used strictly for documentation purposes. """ # Deprecated - use reaching definitions instead. # Symbols # These flags are boolean. IS_LOCAL = 'Symbol is local to the function scope being analyzed.' IS_PARAM = 'Symbol is a parameter to the function being analyzed.' # Scopes # Scopes are represented by objects of type activity.Scope. SCOPE = 'The scope for the annotated node. See activity.py.' # TODO(mdan): Drop these in favor of accessing the child's SCOPE. ARGS_SCOPE = 'The scope for the argument list of a function call.' COND_SCOPE = 'The scope for the test node of a conditional statement.' BODY_SCOPE = ( 'The scope for the main body of a statement (True branch for if ' 'statements, main body for loops).') ORELSE_SCOPE = ( 'The scope for the orelse body of a statement (False branch for if ' 'statements, orelse body for loops).') # Static analysis annotations. DEFINITIONS = ( 'Reaching definition information. See reaching_definitions.py.') ORIG_DEFINITIONS = ( 'The value of DEFINITIONS that applied to the original code before any' ' conversion.') DEFINED_VARS_IN = ( 'Symbols defined when entering the node. See reaching_definitions.py.') LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.') FAIL = object() def keys(node, field_name='___pyct_anno'): if not hasattr(node, field_name): return frozenset() return frozenset(getattr(node, field_name).keys()) def getanno(node, key, default=FAIL, field_name='___pyct_anno'): if (default is FAIL or (hasattr(node, field_name) and (key in getattr(node, field_name)))): return getattr(node, field_name)[key] else: return default def hasanno(node, key, field_name='___pyct_anno'): return hasattr(node, field_name) and key in getattr(node, field_name) def setanno(node, key, value, field_name='___pyct_anno'): annotations = getattr(node, field_name, {}) setattr(node, field_name, annotations) annotations[key] = value # So that the annotations survive gast_to_ast() and ast_to_gast() if field_name not in node._fields: node._fields += (field_name,) def delanno(node, key, field_name='___pyct_anno'): annotations = getattr(node, field_name) del annotations[key] if not annotations: delattr(node, field_name) node._fields = tuple(f for f in node._fields if f != field_name) def copyanno(from_node, to_node, key, field_name='___pyct_anno'): if hasanno(from_node, key, field_name=field_name): setanno( to_node, key, getanno(from_node, key, field_name=field_name), field_name=field_name) def dup(node, copy_map, field_name='___pyct_anno'): """Recursively copies annotations in an AST tree. Args: node: ast.AST copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination key. All annotations with the source key will be copied to identical annotations with the destination key. field_name: str """ for n in gast.walk(node): for k in copy_map: if hasanno(n, k, field_name): setanno(n, copy_map[k], getanno(n, k, field_name), field_name)