Spaces:
Paused
Paused
| """ | |
| MIT License | |
| Copyright (c) 2021 Alex Hall | |
| Permission is hereby granted, free of charge, to any person obtaining a copy | |
| of this software and associated documentation files (the "Software"), to deal | |
| in the Software without restriction, including without limitation the rights | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| copies of the Software, and to permit persons to whom the Software is | |
| furnished to do so, subject to the following conditions: | |
| The above copyright notice and this permission notice shall be included in all | |
| copies or substantial portions of the Software. | |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| SOFTWARE. | |
| """ | |
| import __future__ | |
| import ast | |
| import dis | |
| import inspect | |
| import io | |
| import linecache | |
| import re | |
| import sys | |
| import types | |
| from collections import defaultdict | |
| from copy import deepcopy | |
| from functools import lru_cache | |
| from itertools import islice | |
| from itertools import zip_longest | |
| from operator import attrgetter | |
| from pathlib import Path | |
| from threading import RLock | |
| from tokenize import detect_encoding | |
| from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, \ | |
| Type, TypeVar, Union, cast | |
| if TYPE_CHECKING: # pragma: no cover | |
| from asttokens import ASTTokens, ASTText | |
| from asttokens.asttokens import ASTTextBase | |
| function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef) # type: Tuple[Type, ...] | |
| cache = lru_cache(maxsize=None) | |
| # Type class used to expand out the definition of AST to include fields added by this library | |
| # It's not actually used for anything other than type checking though! | |
| class EnhancedAST(ast.AST): | |
| parent = None # type: EnhancedAST | |
| class Instruction(dis.Instruction): | |
| lineno = None # type: int | |
| # Type class used to expand out the definition of AST to include fields added by this library | |
| # It's not actually used for anything other than type checking though! | |
| class EnhancedInstruction(Instruction): | |
| _copied = None # type: bool | |
| def assert_(condition, message=""): | |
| # type: (Any, str) -> None | |
| """ | |
| Like an assert statement, but unaffected by -O | |
| :param condition: value that is expected to be truthy | |
| :type message: Any | |
| """ | |
| if not condition: | |
| raise AssertionError(str(message)) | |
| def get_instructions(co): | |
| # type: (types.CodeType) -> Iterator[EnhancedInstruction] | |
| lineno = co.co_firstlineno | |
| for inst in dis.get_instructions(co): | |
| inst = cast(EnhancedInstruction, inst) | |
| lineno = inst.starts_line or lineno | |
| assert_(lineno) | |
| inst.lineno = lineno | |
| yield inst | |
| TESTING = 0 | |
| class NotOneValueFound(Exception): | |
| def __init__(self,msg,values=[]): | |
| # type: (str, Sequence) -> None | |
| self.values=values | |
| super(NotOneValueFound,self).__init__(msg) | |
| T = TypeVar('T') | |
| def only(it): | |
| # type: (Iterable[T]) -> T | |
| if isinstance(it, Sized): | |
| if len(it) != 1: | |
| raise NotOneValueFound('Expected one value, found %s' % len(it)) | |
| # noinspection PyTypeChecker | |
| return list(it)[0] | |
| lst = tuple(islice(it, 2)) | |
| if len(lst) == 0: | |
| raise NotOneValueFound('Expected one value, found 0') | |
| if len(lst) > 1: | |
| raise NotOneValueFound('Expected one value, found several',lst) | |
| return lst[0] | |
| class Source(object): | |
| """ | |
| The source code of a single file and associated metadata. | |
| The main method of interest is the classmethod `executing(frame)`. | |
| If you want an instance of this class, don't construct it. | |
| Ideally use the classmethod `for_frame(frame)`. | |
| If you don't have a frame, use `for_filename(filename [, module_globals])`. | |
| These methods cache instances by filename, so at most one instance exists per filename. | |
| Attributes: | |
| - filename | |
| - text | |
| - lines | |
| - tree: AST parsed from text, or None if text is not valid Python | |
| All nodes in the tree have an extra `parent` attribute | |
| Other methods of interest: | |
| - statements_at_line | |
| - asttokens | |
| - code_qualname | |
| """ | |
| def __init__(self, filename, lines): | |
| # type: (str, Sequence[str]) -> None | |
| """ | |
| Don't call this constructor, see the class docstring. | |
| """ | |
| self.filename = filename | |
| self.text = ''.join(lines) | |
| self.lines = [line.rstrip('\r\n') for line in lines] | |
| self._nodes_by_line = defaultdict(list) | |
| self.tree = None | |
| self._qualnames = {} | |
| self._asttokens = None # type: Optional[ASTTokens] | |
| self._asttext = None # type: Optional[ASTText] | |
| try: | |
| self.tree = ast.parse(self.text, filename=filename) | |
| except (SyntaxError, ValueError): | |
| pass | |
| else: | |
| for node in ast.walk(self.tree): | |
| for child in ast.iter_child_nodes(node): | |
| cast(EnhancedAST, child).parent = cast(EnhancedAST, node) | |
| for lineno in node_linenos(node): | |
| self._nodes_by_line[lineno].append(node) | |
| visitor = QualnameVisitor() | |
| visitor.visit(self.tree) | |
| self._qualnames = visitor.qualnames | |
| def for_frame(cls, frame, use_cache=True): | |
| # type: (types.FrameType, bool) -> "Source" | |
| """ | |
| Returns the `Source` object corresponding to the file the frame is executing in. | |
| """ | |
| return cls.for_filename(frame.f_code.co_filename, frame.f_globals or {}, use_cache) | |
| def for_filename( | |
| cls, | |
| filename, | |
| module_globals=None, | |
| use_cache=True, # noqa no longer used | |
| ): | |
| # type: (Union[str, Path], Optional[Dict[str, Any]], bool) -> "Source" | |
| if isinstance(filename, Path): | |
| filename = str(filename) | |
| def get_lines(): | |
| # type: () -> List[str] | |
| return linecache.getlines(cast(str, filename), module_globals) | |
| # Save the current linecache entry, then ensure the cache is up to date. | |
| entry = linecache.cache.get(filename) # type: ignore[attr-defined] | |
| linecache.checkcache(filename) | |
| lines = get_lines() | |
| if entry is not None and not lines: | |
| # There was an entry, checkcache removed it, and nothing replaced it. | |
| # This means the file wasn't simply changed (because the `lines` wouldn't be empty) | |
| # but rather the file was found not to exist, probably because `filename` was fake. | |
| # Restore the original entry so that we still have something. | |
| linecache.cache[filename] = entry # type: ignore[attr-defined] | |
| lines = get_lines() | |
| return cls._for_filename_and_lines(filename, tuple(lines)) | |
| def _for_filename_and_lines(cls, filename, lines): | |
| # type: (str, Sequence[str]) -> "Source" | |
| source_cache = cls._class_local('__source_cache_with_lines', {}) # type: Dict[Tuple[str, Sequence[str]], Source] | |
| try: | |
| return source_cache[(filename, lines)] | |
| except KeyError: | |
| pass | |
| result = source_cache[(filename, lines)] = cls(filename, lines) | |
| return result | |
| def lazycache(cls, frame): | |
| # type: (types.FrameType) -> None | |
| linecache.lazycache(frame.f_code.co_filename, frame.f_globals) | |
| def executing(cls, frame_or_tb): | |
| # type: (Union[types.TracebackType, types.FrameType]) -> "Executing" | |
| """ | |
| Returns an `Executing` object representing the operation | |
| currently executing in the given frame or traceback object. | |
| """ | |
| if isinstance(frame_or_tb, types.TracebackType): | |
| # https://docs.python.org/3/reference/datamodel.html#traceback-objects | |
| # "tb_lineno gives the line number where the exception occurred; | |
| # tb_lasti indicates the precise instruction. | |
| # The line number and last instruction in the traceback may differ | |
| # from the line number of its frame object | |
| # if the exception occurred in a try statement with no matching except clause | |
| # or with a finally clause." | |
| tb = frame_or_tb | |
| frame = tb.tb_frame | |
| lineno = tb.tb_lineno | |
| lasti = tb.tb_lasti | |
| else: | |
| frame = frame_or_tb | |
| lineno = frame.f_lineno | |
| lasti = frame.f_lasti | |
| code = frame.f_code | |
| key = (code, id(code), lasti) | |
| executing_cache = cls._class_local('__executing_cache', {}) # type: Dict[Tuple[types.CodeType, int, int], Any] | |
| args = executing_cache.get(key) | |
| if not args: | |
| node = stmts = decorator = None | |
| source = cls.for_frame(frame) | |
| tree = source.tree | |
| if tree: | |
| try: | |
| stmts = source.statements_at_line(lineno) | |
| if stmts: | |
| if is_ipython_cell_code(code): | |
| decorator, node = find_node_ipython(frame, lasti, stmts, source) | |
| else: | |
| node_finder = NodeFinder(frame, stmts, tree, lasti, source) | |
| node = node_finder.result | |
| decorator = node_finder.decorator | |
| if node: | |
| new_stmts = {statement_containing_node(node)} | |
| assert_(new_stmts <= stmts) | |
| stmts = new_stmts | |
| except Exception: | |
| if TESTING: | |
| raise | |
| executing_cache[key] = args = source, node, stmts, decorator | |
| return Executing(frame, *args) | |
| def _class_local(cls, name, default): | |
| # type: (str, T) -> T | |
| """ | |
| Returns an attribute directly associated with this class | |
| (as opposed to subclasses), setting default if necessary | |
| """ | |
| # classes have a mappingproxy preventing us from using setdefault | |
| result = cls.__dict__.get(name, default) | |
| setattr(cls, name, result) | |
| return result | |
| def statements_at_line(self, lineno): | |
| # type: (int) -> Set[EnhancedAST] | |
| """ | |
| Returns the statement nodes overlapping the given line. | |
| Returns at most one statement unless semicolons are present. | |
| If the `text` attribute is not valid python, meaning | |
| `tree` is None, returns an empty set. | |
| Otherwise, `Source.for_frame(frame).statements_at_line(frame.f_lineno)` | |
| should return at least one statement. | |
| """ | |
| return { | |
| statement_containing_node(node) | |
| for node in | |
| self._nodes_by_line[lineno] | |
| } | |
| def asttext(self): | |
| # type: () -> ASTText | |
| """ | |
| Returns an ASTText object for getting the source of specific AST nodes. | |
| See http://asttokens.readthedocs.io/en/latest/api-index.html | |
| """ | |
| from asttokens import ASTText # must be installed separately | |
| if self._asttext is None: | |
| self._asttext = ASTText(self.text, tree=self.tree, filename=self.filename) | |
| return self._asttext | |
| def asttokens(self): | |
| # type: () -> ASTTokens | |
| """ | |
| Returns an ASTTokens object for getting the source of specific AST nodes. | |
| See http://asttokens.readthedocs.io/en/latest/api-index.html | |
| """ | |
| import asttokens # must be installed separately | |
| if self._asttokens is None: | |
| if hasattr(asttokens, 'ASTText'): | |
| self._asttokens = self.asttext().asttokens | |
| else: # pragma: no cover | |
| self._asttokens = asttokens.ASTTokens(self.text, tree=self.tree, filename=self.filename) | |
| return self._asttokens | |
| def _asttext_base(self): | |
| # type: () -> ASTTextBase | |
| import asttokens # must be installed separately | |
| if hasattr(asttokens, 'ASTText'): | |
| return self.asttext() | |
| else: # pragma: no cover | |
| return self.asttokens() | |
| def decode_source(source): | |
| # type: (Union[str, bytes]) -> str | |
| if isinstance(source, bytes): | |
| encoding = Source.detect_encoding(source) | |
| return source.decode(encoding) | |
| else: | |
| return source | |
| def detect_encoding(source): | |
| # type: (bytes) -> str | |
| return detect_encoding(io.BytesIO(source).readline)[0] | |
| def code_qualname(self, code): | |
| # type: (types.CodeType) -> str | |
| """ | |
| Imitates the __qualname__ attribute of functions for code objects. | |
| Given: | |
| - A function `func` | |
| - A frame `frame` for an execution of `func`, meaning: | |
| `frame.f_code is func.__code__` | |
| `Source.for_frame(frame).code_qualname(frame.f_code)` | |
| will be equal to `func.__qualname__`*. Works for Python 2 as well, | |
| where of course no `__qualname__` attribute exists. | |
| Falls back to `code.co_name` if there is no appropriate qualname. | |
| Based on https://github.com/wbolster/qualname | |
| (* unless `func` is a lambda | |
| nested inside another lambda on the same line, in which case | |
| the outer lambda's qualname will be returned for the codes | |
| of both lambdas) | |
| """ | |
| assert_(code.co_filename == self.filename) | |
| return self._qualnames.get((code.co_name, code.co_firstlineno), code.co_name) | |
| class Executing(object): | |
| """ | |
| Information about the operation a frame is currently executing. | |
| Generally you will just want `node`, which is the AST node being executed, | |
| or None if it's unknown. | |
| If a decorator is currently being called, then: | |
| - `node` is a function or class definition | |
| - `decorator` is the expression in `node.decorator_list` being called | |
| - `statements == {node}` | |
| """ | |
| def __init__(self, frame, source, node, stmts, decorator): | |
| # type: (types.FrameType, Source, EnhancedAST, Set[ast.stmt], Optional[EnhancedAST]) -> None | |
| self.frame = frame | |
| self.source = source | |
| self.node = node | |
| self.statements = stmts | |
| self.decorator = decorator | |
| def code_qualname(self): | |
| # type: () -> str | |
| return self.source.code_qualname(self.frame.f_code) | |
| def text(self): | |
| # type: () -> str | |
| return self.source._asttext_base().get_text(self.node) | |
| def text_range(self): | |
| # type: () -> Tuple[int, int] | |
| return self.source._asttext_base().get_text_range(self.node) | |
| class QualnameVisitor(ast.NodeVisitor): | |
| def __init__(self): | |
| # type: () -> None | |
| super(QualnameVisitor, self).__init__() | |
| self.stack = [] # type: List[str] | |
| self.qualnames = {} # type: Dict[Tuple[str, int], str] | |
| def add_qualname(self, node, name=None): | |
| # type: (ast.AST, Optional[str]) -> None | |
| name = name or node.name # type: ignore[attr-defined] | |
| self.stack.append(name) | |
| if getattr(node, 'decorator_list', ()): | |
| lineno = node.decorator_list[0].lineno # type: ignore[attr-defined] | |
| else: | |
| lineno = node.lineno # type: ignore[attr-defined] | |
| self.qualnames.setdefault((name, lineno), ".".join(self.stack)) | |
| def visit_FunctionDef(self, node, name=None): | |
| # type: (ast.AST, Optional[str]) -> None | |
| assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)), node | |
| self.add_qualname(node, name) | |
| self.stack.append('<locals>') | |
| children = [] # type: Sequence[ast.AST] | |
| if isinstance(node, ast.Lambda): | |
| children = [node.body] | |
| else: | |
| children = node.body | |
| for child in children: | |
| self.visit(child) | |
| self.stack.pop() | |
| self.stack.pop() | |
| # Find lambdas in the function definition outside the body, | |
| # e.g. decorators or default arguments | |
| # Based on iter_child_nodes | |
| for field, child in ast.iter_fields(node): | |
| if field == 'body': | |
| continue | |
| if isinstance(child, ast.AST): | |
| self.visit(child) | |
| elif isinstance(child, list): | |
| for grandchild in child: | |
| if isinstance(grandchild, ast.AST): | |
| self.visit(grandchild) | |
| visit_AsyncFunctionDef = visit_FunctionDef | |
| def visit_Lambda(self, node): | |
| # type: (ast.AST) -> None | |
| assert isinstance(node, ast.Lambda) | |
| self.visit_FunctionDef(node, '<lambda>') | |
| def visit_ClassDef(self, node): | |
| # type: (ast.AST) -> None | |
| assert isinstance(node, ast.ClassDef) | |
| self.add_qualname(node) | |
| self.generic_visit(node) | |
| self.stack.pop() | |
| future_flags = sum( | |
| getattr(__future__, fname).compiler_flag for fname in __future__.all_feature_names | |
| ) | |
| def compile_similar_to(source, matching_code): | |
| # type: (ast.Module, types.CodeType) -> Any | |
| return compile( | |
| source, | |
| matching_code.co_filename, | |
| 'exec', | |
| flags=future_flags & matching_code.co_flags, | |
| dont_inherit=True, | |
| ) | |
| sentinel = 'io8urthglkjdghvljusketgIYRFYUVGHFRTBGVHKGF78678957647698' | |
| def is_rewritten_by_pytest(code): | |
| # type: (types.CodeType) -> bool | |
| return any( | |
| bc.opname != "LOAD_CONST" and isinstance(bc.argval,str) and bc.argval.startswith("@py") | |
| for bc in get_instructions(code) | |
| ) | |
| class SentinelNodeFinder(object): | |
| result = None # type: EnhancedAST | |
| def __init__(self, frame, stmts, tree, lasti, source): | |
| # type: (types.FrameType, Set[EnhancedAST], ast.Module, int, Source) -> None | |
| assert_(stmts) | |
| self.frame = frame | |
| self.tree = tree | |
| self.code = code = frame.f_code | |
| self.is_pytest = is_rewritten_by_pytest(code) | |
| if self.is_pytest: | |
| self.ignore_linenos = frozenset(assert_linenos(tree)) | |
| else: | |
| self.ignore_linenos = frozenset() | |
| self.decorator = None | |
| self.instruction = instruction = self.get_actual_current_instruction(lasti) | |
| op_name = instruction.opname | |
| extra_filter = lambda e: True # type: Callable[[Any], bool] | |
| ctx = type(None) # type: Type | |
| typ = type(None) # type: Type | |
| if op_name.startswith('CALL_'): | |
| typ = ast.Call | |
| elif op_name.startswith(('BINARY_SUBSCR', 'SLICE+')): | |
| typ = ast.Subscript | |
| ctx = ast.Load | |
| elif op_name.startswith('BINARY_'): | |
| typ = ast.BinOp | |
| op_type = dict( | |
| BINARY_POWER=ast.Pow, | |
| BINARY_MULTIPLY=ast.Mult, | |
| BINARY_MATRIX_MULTIPLY=getattr(ast, "MatMult", ()), | |
| BINARY_FLOOR_DIVIDE=ast.FloorDiv, | |
| BINARY_TRUE_DIVIDE=ast.Div, | |
| BINARY_MODULO=ast.Mod, | |
| BINARY_ADD=ast.Add, | |
| BINARY_SUBTRACT=ast.Sub, | |
| BINARY_LSHIFT=ast.LShift, | |
| BINARY_RSHIFT=ast.RShift, | |
| BINARY_AND=ast.BitAnd, | |
| BINARY_XOR=ast.BitXor, | |
| BINARY_OR=ast.BitOr, | |
| )[op_name] | |
| extra_filter = lambda e: isinstance(e.op, op_type) | |
| elif op_name.startswith('UNARY_'): | |
| typ = ast.UnaryOp | |
| op_type = dict( | |
| UNARY_POSITIVE=ast.UAdd, | |
| UNARY_NEGATIVE=ast.USub, | |
| UNARY_NOT=ast.Not, | |
| UNARY_INVERT=ast.Invert, | |
| )[op_name] | |
| extra_filter = lambda e: isinstance(e.op, op_type) | |
| elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'): | |
| typ = ast.Attribute | |
| ctx = ast.Load | |
| extra_filter = lambda e: attr_names_match(e.attr, instruction.argval) | |
| elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'): | |
| typ = ast.Name | |
| ctx = ast.Load | |
| extra_filter = lambda e: e.id == instruction.argval | |
| elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'): | |
| typ = ast.Compare | |
| extra_filter = lambda e: len(e.ops) == 1 | |
| elif op_name.startswith(('STORE_SLICE', 'STORE_SUBSCR')): | |
| ctx = ast.Store | |
| typ = ast.Subscript | |
| elif op_name.startswith('STORE_ATTR'): | |
| ctx = ast.Store | |
| typ = ast.Attribute | |
| extra_filter = lambda e: attr_names_match(e.attr, instruction.argval) | |
| else: | |
| raise RuntimeError(op_name) | |
| with lock: | |
| exprs = { | |
| cast(EnhancedAST, node) | |
| for stmt in stmts | |
| for node in ast.walk(stmt) | |
| if isinstance(node, typ) | |
| if isinstance(getattr(node, "ctx", None), ctx) | |
| if extra_filter(node) | |
| if statement_containing_node(node) == stmt | |
| } | |
| if ctx == ast.Store: | |
| # No special bytecode tricks here. | |
| # We can handle multiple assigned attributes with different names, | |
| # but only one assigned subscript. | |
| self.result = only(exprs) | |
| return | |
| matching = list(self.matching_nodes(exprs)) | |
| if not matching and typ == ast.Call: | |
| self.find_decorator(stmts) | |
| else: | |
| self.result = only(matching) | |
| def find_decorator(self, stmts): | |
| # type: (Union[List[EnhancedAST], Set[EnhancedAST]]) -> None | |
| stmt = only(stmts) | |
| assert_(isinstance(stmt, (ast.ClassDef, function_node_types))) | |
| decorators = stmt.decorator_list # type: ignore[attr-defined] | |
| assert_(decorators) | |
| line_instructions = [ | |
| inst | |
| for inst in self.clean_instructions(self.code) | |
| if inst.lineno == self.frame.f_lineno | |
| ] | |
| last_decorator_instruction_index = [ | |
| i | |
| for i, inst in enumerate(line_instructions) | |
| if inst.opname == "CALL_FUNCTION" | |
| ][-1] | |
| assert_( | |
| line_instructions[last_decorator_instruction_index + 1].opname.startswith( | |
| "STORE_" | |
| ) | |
| ) | |
| decorator_instructions = line_instructions[ | |
| last_decorator_instruction_index | |
| - len(decorators) | |
| + 1 : last_decorator_instruction_index | |
| + 1 | |
| ] | |
| assert_({inst.opname for inst in decorator_instructions} == {"CALL_FUNCTION"}) | |
| decorator_index = decorator_instructions.index(self.instruction) | |
| decorator = decorators[::-1][decorator_index] | |
| self.decorator = decorator | |
| self.result = stmt | |
| def clean_instructions(self, code): | |
| # type: (types.CodeType) -> List[EnhancedInstruction] | |
| return [ | |
| inst | |
| for inst in get_instructions(code) | |
| if inst.opname not in ("EXTENDED_ARG", "NOP") | |
| if inst.lineno not in self.ignore_linenos | |
| ] | |
| def get_original_clean_instructions(self): | |
| # type: () -> List[EnhancedInstruction] | |
| result = self.clean_instructions(self.code) | |
| # pypy sometimes (when is not clear) | |
| # inserts JUMP_IF_NOT_DEBUG instructions in bytecode | |
| # If they're not present in our compiled instructions, | |
| # ignore them in the original bytecode | |
| if not any( | |
| inst.opname == "JUMP_IF_NOT_DEBUG" | |
| for inst in self.compile_instructions() | |
| ): | |
| result = [ | |
| inst for inst in result | |
| if inst.opname != "JUMP_IF_NOT_DEBUG" | |
| ] | |
| return result | |
| def matching_nodes(self, exprs): | |
| # type: (Set[EnhancedAST]) -> Iterator[EnhancedAST] | |
| original_instructions = self.get_original_clean_instructions() | |
| original_index = only( | |
| i | |
| for i, inst in enumerate(original_instructions) | |
| if inst == self.instruction | |
| ) | |
| for expr_index, expr in enumerate(exprs): | |
| setter = get_setter(expr) | |
| assert setter is not None | |
| # noinspection PyArgumentList | |
| replacement = ast.BinOp( | |
| left=expr, | |
| op=ast.Pow(), | |
| right=ast.Str(s=sentinel), | |
| ) | |
| ast.fix_missing_locations(replacement) | |
| setter(replacement) | |
| try: | |
| instructions = self.compile_instructions() | |
| finally: | |
| setter(expr) | |
| if sys.version_info >= (3, 10): | |
| try: | |
| handle_jumps(instructions, original_instructions) | |
| except Exception: | |
| # Give other candidates a chance | |
| if TESTING or expr_index < len(exprs) - 1: | |
| continue | |
| raise | |
| indices = [ | |
| i | |
| for i, instruction in enumerate(instructions) | |
| if instruction.argval == sentinel | |
| ] | |
| # There can be several indices when the bytecode is duplicated, | |
| # as happens in a finally block in 3.9+ | |
| # First we remove the opcodes caused by our modifications | |
| for index_num, sentinel_index in enumerate(indices): | |
| # Adjustment for removing sentinel instructions below | |
| # in past iterations | |
| sentinel_index -= index_num * 2 | |
| assert_(instructions.pop(sentinel_index).opname == 'LOAD_CONST') | |
| assert_(instructions.pop(sentinel_index).opname == 'BINARY_POWER') | |
| # Then we see if any of the instruction indices match | |
| for index_num, sentinel_index in enumerate(indices): | |
| sentinel_index -= index_num * 2 | |
| new_index = sentinel_index - 1 | |
| if new_index != original_index: | |
| continue | |
| original_inst = original_instructions[original_index] | |
| new_inst = instructions[new_index] | |
| # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)' | |
| # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT | |
| if ( | |
| original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP') | |
| and original_inst.arg != new_inst.arg # type: ignore[attr-defined] | |
| and ( | |
| original_instructions[original_index + 1].opname | |
| != instructions[new_index + 1].opname == 'UNARY_NOT' | |
| )): | |
| # Remove the difference for the upcoming assert | |
| instructions.pop(new_index + 1) | |
| # Check that the modified instructions don't have anything unexpected | |
| # 3.10 is a bit too weird to assert this in all cases but things still work | |
| if sys.version_info < (3, 10): | |
| for inst1, inst2 in zip_longest( | |
| original_instructions, instructions | |
| ): | |
| assert_(inst1 and inst2 and opnames_match(inst1, inst2)) | |
| yield expr | |
| def compile_instructions(self): | |
| # type: () -> List[EnhancedInstruction] | |
| module_code = compile_similar_to(self.tree, self.code) | |
| code = only(self.find_codes(module_code)) | |
| return self.clean_instructions(code) | |
| def find_codes(self, root_code): | |
| # type: (types.CodeType) -> list | |
| checks = [ | |
| attrgetter('co_firstlineno'), | |
| attrgetter('co_freevars'), | |
| attrgetter('co_cellvars'), | |
| lambda c: is_ipython_cell_code_name(c.co_name) or c.co_name, | |
| ] # type: List[Callable] | |
| if not self.is_pytest: | |
| checks += [ | |
| attrgetter('co_names'), | |
| attrgetter('co_varnames'), | |
| ] | |
| def matches(c): | |
| # type: (types.CodeType) -> bool | |
| return all( | |
| f(c) == f(self.code) | |
| for f in checks | |
| ) | |
| code_options = [] | |
| if matches(root_code): | |
| code_options.append(root_code) | |
| def finder(code): | |
| # type: (types.CodeType) -> None | |
| for const in code.co_consts: | |
| if not inspect.iscode(const): | |
| continue | |
| if matches(const): | |
| code_options.append(const) | |
| finder(const) | |
| finder(root_code) | |
| return code_options | |
| def get_actual_current_instruction(self, lasti): | |
| # type: (int) -> EnhancedInstruction | |
| """ | |
| Get the instruction corresponding to the current | |
| frame offset, skipping EXTENDED_ARG instructions | |
| """ | |
| # Don't use get_original_clean_instructions | |
| # because we need the actual instructions including | |
| # EXTENDED_ARG | |
| instructions = list(get_instructions(self.code)) | |
| index = only( | |
| i | |
| for i, inst in enumerate(instructions) | |
| if inst.offset == lasti | |
| ) | |
| while True: | |
| instruction = instructions[index] | |
| if instruction.opname != "EXTENDED_ARG": | |
| return instruction | |
| index += 1 | |
| def non_sentinel_instructions(instructions, start): | |
| # type: (List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction]] | |
| """ | |
| Yields (index, instruction) pairs excluding the basic | |
| instructions introduced by the sentinel transformation | |
| """ | |
| skip_power = False | |
| for i, inst in islice(enumerate(instructions), start, None): | |
| if inst.argval == sentinel: | |
| assert_(inst.opname == "LOAD_CONST") | |
| skip_power = True | |
| continue | |
| elif skip_power: | |
| assert_(inst.opname == "BINARY_POWER") | |
| skip_power = False | |
| continue | |
| yield i, inst | |
| def walk_both_instructions(original_instructions, original_start, instructions, start): | |
| # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction, int, EnhancedInstruction]] | |
| """ | |
| Yields matching indices and instructions from the new and original instructions, | |
| leaving out changes made by the sentinel transformation. | |
| """ | |
| original_iter = islice(enumerate(original_instructions), original_start, None) | |
| new_iter = non_sentinel_instructions(instructions, start) | |
| inverted_comparison = False | |
| while True: | |
| try: | |
| original_i, original_inst = next(original_iter) | |
| new_i, new_inst = next(new_iter) | |
| except StopIteration: | |
| return | |
| if ( | |
| inverted_comparison | |
| and original_inst.opname != new_inst.opname == "UNARY_NOT" | |
| ): | |
| new_i, new_inst = next(new_iter) | |
| inverted_comparison = ( | |
| original_inst.opname == new_inst.opname in ("CONTAINS_OP", "IS_OP") | |
| and original_inst.arg != new_inst.arg # type: ignore[attr-defined] | |
| ) | |
| yield original_i, original_inst, new_i, new_inst | |
| def handle_jumps(instructions, original_instructions): | |
| # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> None | |
| """ | |
| Transforms instructions in place until it looks more like original_instructions. | |
| This is only needed in 3.10+ where optimisations lead to more drastic changes | |
| after the sentinel transformation. | |
| Replaces JUMP instructions that aren't also present in original_instructions | |
| with the sections that they jump to until a raise or return. | |
| In some other cases duplication found in `original_instructions` | |
| is replicated in `instructions`. | |
| """ | |
| while True: | |
| for original_i, original_inst, new_i, new_inst in walk_both_instructions( | |
| original_instructions, 0, instructions, 0 | |
| ): | |
| if opnames_match(original_inst, new_inst): | |
| continue | |
| if "JUMP" in new_inst.opname and "JUMP" not in original_inst.opname: | |
| # Find where the new instruction is jumping to, ignoring | |
| # instructions which have been copied in previous iterations | |
| start = only( | |
| i | |
| for i, inst in enumerate(instructions) | |
| if inst.offset == new_inst.argval | |
| and not getattr(inst, "_copied", False) | |
| ) | |
| # Replace the jump instruction with the jumped to section of instructions | |
| # That section may also be deleted if it's not similarly duplicated | |
| # in original_instructions | |
| new_instructions = handle_jump( | |
| original_instructions, original_i, instructions, start | |
| ) | |
| assert new_instructions is not None | |
| instructions[new_i : new_i + 1] = new_instructions | |
| else: | |
| # Extract a section of original_instructions from original_i to return/raise | |
| orig_section = [] | |
| for section_inst in original_instructions[original_i:]: | |
| orig_section.append(section_inst) | |
| if section_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"): | |
| break | |
| else: | |
| # No return/raise - this is just a mismatch we can't handle | |
| raise AssertionError | |
| instructions[new_i:new_i] = only(find_new_matching(orig_section, instructions)) | |
| # instructions has been modified, the for loop can't sensibly continue | |
| # Restart it from the beginning, checking for other issues | |
| break | |
| else: # No mismatched jumps found, we're done | |
| return | |
| def find_new_matching(orig_section, instructions): | |
| # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> Iterator[List[EnhancedInstruction]] | |
| """ | |
| Yields sections of `instructions` which match `orig_section`. | |
| The yielded sections include sentinel instructions, but these | |
| are ignored when checking for matches. | |
| """ | |
| for start in range(len(instructions) - len(orig_section)): | |
| indices, dup_section = zip( | |
| *islice( | |
| non_sentinel_instructions(instructions, start), | |
| len(orig_section), | |
| ) | |
| ) | |
| if len(dup_section) < len(orig_section): | |
| return | |
| if sections_match(orig_section, dup_section): | |
| yield instructions[start:indices[-1] + 1] | |
| def handle_jump(original_instructions, original_start, instructions, start): | |
| # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Optional[List[EnhancedInstruction]] | |
| """ | |
| Returns the section of instructions starting at `start` and ending | |
| with a RETURN_VALUE or RAISE_VARARGS instruction. | |
| There should be a matching section in original_instructions starting at original_start. | |
| If that section doesn't appear elsewhere in original_instructions, | |
| then also delete the returned section of instructions. | |
| """ | |
| for original_j, original_inst, new_j, new_inst in walk_both_instructions( | |
| original_instructions, original_start, instructions, start | |
| ): | |
| assert_(opnames_match(original_inst, new_inst)) | |
| if original_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"): | |
| inlined = deepcopy(instructions[start : new_j + 1]) | |
| for inl in inlined: | |
| inl._copied = True | |
| orig_section = original_instructions[original_start : original_j + 1] | |
| if not check_duplicates( | |
| original_start, orig_section, original_instructions | |
| ): | |
| instructions[start : new_j + 1] = [] | |
| return inlined | |
| return None | |
| def check_duplicates(original_i, orig_section, original_instructions): | |
| # type: (int, List[EnhancedInstruction], List[EnhancedInstruction]) -> bool | |
| """ | |
| Returns True if a section of original_instructions starting somewhere other | |
| than original_i and matching orig_section is found, i.e. orig_section is duplicated. | |
| """ | |
| for dup_start in range(len(original_instructions)): | |
| if dup_start == original_i: | |
| continue | |
| dup_section = original_instructions[dup_start : dup_start + len(orig_section)] | |
| if len(dup_section) < len(orig_section): | |
| return False | |
| if sections_match(orig_section, dup_section): | |
| return True | |
| return False | |
| def sections_match(orig_section, dup_section): | |
| # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> bool | |
| """ | |
| Returns True if the given lists of instructions have matching linenos and opnames. | |
| """ | |
| return all( | |
| ( | |
| orig_inst.lineno == dup_inst.lineno | |
| # POP_BLOCKs have been found to have differing linenos in innocent cases | |
| or "POP_BLOCK" == orig_inst.opname == dup_inst.opname | |
| ) | |
| and opnames_match(orig_inst, dup_inst) | |
| for orig_inst, dup_inst in zip(orig_section, dup_section) | |
| ) | |
| def opnames_match(inst1, inst2): | |
| # type: (Instruction, Instruction) -> bool | |
| return ( | |
| inst1.opname == inst2.opname | |
| or "JUMP" in inst1.opname | |
| and "JUMP" in inst2.opname | |
| or (inst1.opname == "PRINT_EXPR" and inst2.opname == "POP_TOP") | |
| or ( | |
| inst1.opname in ("LOAD_METHOD", "LOOKUP_METHOD") | |
| and inst2.opname == "LOAD_ATTR" | |
| ) | |
| or (inst1.opname == "CALL_METHOD" and inst2.opname == "CALL_FUNCTION") | |
| ) | |
| def get_setter(node): | |
| # type: (EnhancedAST) -> Optional[Callable[[ast.AST], None]] | |
| parent = node.parent | |
| for name, field in ast.iter_fields(parent): | |
| if field is node: | |
| def setter(new_node): | |
| # type: (ast.AST) -> None | |
| return setattr(parent, name, new_node) | |
| return setter | |
| elif isinstance(field, list): | |
| for i, item in enumerate(field): | |
| if item is node: | |
| def setter(new_node): | |
| # type: (ast.AST) -> None | |
| field[i] = new_node | |
| return setter | |
| return None | |
| lock = RLock() | |
| def statement_containing_node(node): | |
| # type: (ast.AST) -> EnhancedAST | |
| while not isinstance(node, ast.stmt): | |
| node = cast(EnhancedAST, node).parent | |
| return cast(EnhancedAST, node) | |
| def assert_linenos(tree): | |
| # type: (ast.AST) -> Iterator[int] | |
| for node in ast.walk(tree): | |
| if ( | |
| hasattr(node, 'parent') and | |
| isinstance(statement_containing_node(node), ast.Assert) | |
| ): | |
| for lineno in node_linenos(node): | |
| yield lineno | |
| def _extract_ipython_statement(stmt): | |
| # type: (EnhancedAST) -> ast.Module | |
| # IPython separates each statement in a cell to be executed separately | |
| # So NodeFinder should only compile one statement at a time or it | |
| # will find a code mismatch. | |
| while not isinstance(stmt.parent, ast.Module): | |
| stmt = stmt.parent | |
| # use `ast.parse` instead of `ast.Module` for better portability | |
| # python3.8 changes the signature of `ast.Module` | |
| # Inspired by https://github.com/pallets/werkzeug/pull/1552/files | |
| tree = ast.parse("") | |
| tree.body = [cast(ast.stmt, stmt)] | |
| ast.copy_location(tree, stmt) | |
| return tree | |
| def is_ipython_cell_code_name(code_name): | |
| # type: (str) -> bool | |
| return bool(re.match(r"(<module>|<cell line: \d+>)$", code_name)) | |
| def is_ipython_cell_filename(filename): | |
| # type: (str) -> bool | |
| return bool(re.search(r"<ipython-input-|[/\\]ipykernel_\d+[/\\]", filename)) | |
| def is_ipython_cell_code(code_obj): | |
| # type: (types.CodeType) -> bool | |
| return ( | |
| is_ipython_cell_filename(code_obj.co_filename) and | |
| is_ipython_cell_code_name(code_obj.co_name) | |
| ) | |
| def find_node_ipython(frame, lasti, stmts, source): | |
| # type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]] | |
| node = decorator = None | |
| for stmt in stmts: | |
| tree = _extract_ipython_statement(stmt) | |
| try: | |
| node_finder = NodeFinder(frame, stmts, tree, lasti, source) | |
| if (node or decorator) and (node_finder.result or node_finder.decorator): | |
| # Found potential nodes in separate statements, | |
| # cannot resolve ambiguity, give up here | |
| return None, None | |
| node = node_finder.result | |
| decorator = node_finder.decorator | |
| except Exception: | |
| pass | |
| return decorator, node | |
| def attr_names_match(attr, argval): | |
| # type: (str, str) -> bool | |
| """ | |
| Checks that the user-visible attr (from ast) can correspond to | |
| the argval in the bytecode, i.e. the real attribute fetched internally, | |
| which may be mangled for private attributes. | |
| """ | |
| if attr == argval: | |
| return True | |
| if not attr.startswith("__"): | |
| return False | |
| return bool(re.match(r"^_\w+%s$" % attr, argval)) | |
| def node_linenos(node): | |
| # type: (ast.AST) -> Iterator[int] | |
| if hasattr(node, "lineno"): | |
| linenos = [] # type: Sequence[int] | |
| if hasattr(node, "end_lineno") and isinstance(node, ast.expr): | |
| assert node.end_lineno is not None # type: ignore[attr-defined] | |
| linenos = range(node.lineno, node.end_lineno + 1) # type: ignore[attr-defined] | |
| else: | |
| linenos = [node.lineno] # type: ignore[attr-defined] | |
| for lineno in linenos: | |
| yield lineno | |
| if sys.version_info >= (3, 11): | |
| from ._position_node_finder import PositionNodeFinder as NodeFinder | |
| else: | |
| NodeFinder = SentinelNodeFinder | |