# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast
import six

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
from __future__ import division
from __future__ import print_function
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


_LEADING_WHITESPACE = re.compile(r'\s*')


def _unfold_continuations(code_string):
  """Removes any backslash line continuations from the code."""
  return code_string.replace('\\\n', '')


def dedent_block(code_string):
  """Dedents a block of code so that its first line starts at row zero."""

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(six.StringIO(code_string).readline)
  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      block_level = len(block_indentation)
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
    tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`.
  # So instead of using the output of untokenize directly, we match the leading
  # whitespace on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
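

# --- Illustrative usage sketch; not part of the original module. ---
# dedent_block lets source extracted from inside a class or function body
# (e.g. via inspect.getsource) be re-parsed at module level. The helper name
# and the sample input below are hypothetical.
def _example_dedent_block():  # pragma: no cover
  indented_source = (
      '  def f(x):\n'
      '    return x + 1\n')
  # Prints the same block shifted to column zero, roughly:
  # def f(x):
  #   return x + 1
  print(dedent_block(indented_source))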


def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have
    added).
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
    raise ValueError(
        'Unable to locate the source code of {}. Note that functions defined'
        ' in certain environments, like the interactive Python shell, do not'
        ' expose their source code. If that is the case, you should define'
        ' them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        ' @tf.autograph.do_not_convert. Original error: {}'.format(entity, e))

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in Python 3.8+.
    code_lines[-1] = code_lines[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    code_lines[0] = code_lines[0][col_offset:]

  code_block = '\n'.join([c.rstrip() for c in code_lines])

  return node, code_block


def _arg_name(node):
  if node is None:
    return None
  if isinstance(node, gast.Name):
    return node.id
  assert isinstance(node, str)
  return node


def _node_matches_argspec(node, func):
  """Returns True if node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True
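

# --- Illustrative usage sketch; not part of the original module. ---
# parse_entity returns both the gast node and the (possibly prefixed) source
# that was parsed to produce it. The helper name and the sample function are
# hypothetical.
def _example_parse_entity():  # pragma: no cover

  def square(x):
    return x * x

  node, source = parse_entity(square, future_features=())
  print(type(node))  # A gast.FunctionDef node.
  print(source)      # Roughly: 'def square(x):\n  return x * x\n'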


def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have
    added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    raise errors.UnsupportedLanguageElementError(
        'could not parse the source code of {}:'
        ' no matching AST found'.format(lam))

  # Attempt to narrow down the selection by signature if multiple nodes are
  # found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if we could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(
          i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      'could not parse the source code of {}: found multiple definitions with'
      ' identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      ' unique argument names.\n{}'.format(lam, matches))
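

# --- Illustrative usage sketch; not part of the original module. ---
# Two lambdas with identical signatures on the same line are ambiguous at the
# AST level and make _parse_lambda raise; distinct argument names let the
# argspec matching above disambiguate them. The helper and variable names are
# hypothetical, and the printed source assumes Python 3.8+ end_col_offset
# support.
def _example_parse_lambda():  # pragma: no cover
  add_one, sub_one = (lambda x: x + 1), (lambda y: y - 1)
  node, source = _parse_lambda(add_one)
  print(source)  # Roughly: 'lambda x: x + 1'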


# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should
      be dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly
      one AST node.

  Returns:
    ast.AST
  """
  module_node = gast.parse(src)
  nodes = module_node.body
  if preamble_len:
    nodes = nodes[preamble_len:]
  if single_node:
    if len(nodes) != 1:
      raise ValueError('expected exactly one node, found {}'.format(nodes))
    return nodes[0]
  return nodes


def parse_expression(src):
  """Returns the AST of given expression.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  src = STANDARD_PREAMBLE + src.strip()
  node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  if __debug__:
    if not isinstance(node, gast.Expr):
      raise ValueError(
          'expected a single expression, found instead {}'.format(node))
  return node.value


def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object.
    indentation: Unused, deprecated. The returned code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text, the source code generated from the AST object.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    if isinstance(n, gast.AST):
      n = gast.gast_to_ast(n)
    codes.append(astunparse.unparse(n).strip())

  return '\n'.join(codes)
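

# --- Illustrative usage sketch; not part of the original module. ---
# parse_expression and unparse form a rough round trip: source text in, gast
# node out, and back to normalized source text. The helper name is
# hypothetical.
def _example_parse_unparse():  # pragma: no cover
  node = parse_expression('a + b * 2')
  print(type(node))  # A gast.BinOp node.
  # astunparse adds explicit parentheses, so this prints roughly:
  # (a + (b * 2))
  print(unparse(node, include_encoding_marker=False))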