diff --git a/flake8_pyi/visitor.py b/flake8_pyi/visitor.py index 4bd8aeb..059f846 100644 --- a/flake8_pyi/visitor.py +++ b/flake8_pyi/visitor.py @@ -4,6 +4,7 @@ import ast import re import sys +import types from collections import Counter, defaultdict from collections.abc import Container, Iterable, Iterator, Sequence, Set as AbstractSet from contextlib import contextmanager @@ -179,18 +180,17 @@ def _is_object(node: ast.AST | None, name: str, *, from_: Container[str]) -> boo >>> _is_AsyncIterator(_ast_node_for("collections.abc.AsyncIterator")) True """ - if _is_name(node, name): - return True - if not (isinstance(node, ast.Attribute) and node.attr == name): - return False - node_value = node.value - if isinstance(node_value, ast.Name): - return node_value.id in from_ - return ( - isinstance(node_value, ast.Attribute) - and isinstance(node_value.value, ast.Name) - and f"{node_value.value.id}.{node_value.attr}" in from_ - ) + match node: + case ast.Name(id): + return id == name + case ast.Attribute(value=ast.Name(id), attr=attr): + return attr == name and id in from_ + case ast.Attribute( + value=ast.Attribute(value=ast.Name(id), attr=inner_attr), attr=attr + ): + return attr == name and f"{id}.{inner_attr}" in from_ + case _: + return False _is_BaseException = partial(_is_object, name="BaseException", from_={"builtins"}) @@ -261,18 +261,19 @@ def _get_name_of_class_if_from_modules( >>> _get_name_of_class_if_from_modules(int_node, modules={'typing'}) is None True """ - if isinstance(classnode, ast.Name): - return classnode.id - if isinstance(classnode, ast.Attribute): - module_node = classnode.value - if isinstance(module_node, ast.Name) and module_node.id in modules: - return classnode.attr - if ( - isinstance(module_node, ast.Attribute) - and isinstance(module_node.value, ast.Name) - and f"{module_node.value.id}.{module_node.attr}" in modules + match classnode: + case ast.Name(id): + return id + case ast.Attribute(value=ast.Name(id=module_name), attr=attr): + if module_name in modules: + return attr + case ast.Attribute( + value=ast.Attribute(value=ast.Name(id=module_name), attr=inner_attr), + attr=attr, ): - return classnode.attr + if f"{module_name}.{inner_attr}" in modules: + return attr + return None @@ -620,28 +621,32 @@ def _analyze_classdef(node: ast.ClassDef) -> EnclosingClassContext: bases_map: defaultdict[str, set[str | None]] = defaultdict(set) def _unravel(node: ast.expr) -> str | None: - if isinstance(node, ast.Name): - return node.id - if isinstance(node, ast.Attribute): - value = _unravel(node.value) - if value is None: + match node: + case ast.Name(id): + return id + case ast.Attribute(value=value, attr=attr): + value_str = _unravel(value) + if value_str is None: + return None + return f"{value_str}.{attr}" + case _: return None - return f"{value}.{node.attr}" - return None def _analyze_base_node( base_node: ast.expr, top_level: bool = True ) -> ClassBase | None: - if isinstance(base_node, ast.Name): - return ClassBase(None, base_node.id) - if isinstance(base_node, ast.Attribute): - value = _unravel(base_node.value) - if value is None: + match base_node: + case ast.Name(id): + return ClassBase(None, id) + case ast.Attribute(value=value, attr=attr): + unravelled = _unravel(value) + if unravelled is None: + return None + return ClassBase(unravelled, attr) + case ast.Subscript() if top_level: + return _analyze_base_node(base_node.value, top_level=False) + case _: return None - return ClassBase(value, base_node.attr) - if isinstance(base_node, ast.Subscript) and top_level: - return _analyze_base_node(base_node.value, top_level=False) - return None for base_node in node.bases: base = _analyze_base_node(base_node) @@ -668,84 +673,69 @@ def _is_valid_default_value_with_annotation( the validity of default values for ast.AnnAssign nodes. (E.g. `foo: int = 5` is OK, but `foo: TypeVar = TypeVar("foo")` is not.) """ - # lists, tuples, sets - if isinstance(node, (ast.List, ast.Tuple, ast.Set)): - return ( - allow_containers - and len(node.elts) <= 10 - and all( - _is_valid_default_value_with_annotation(elt, allow_containers=False) - for elt in node.elts + match node: + case ast.List(elts) | ast.Tuple(elts) | ast.Set(elts): + return ( + allow_containers + and len(elts) <= 10 + and all( + _is_valid_default_value_with_annotation(elt, allow_containers=False) + for elt in elts + ) ) - ) - # dicts - if isinstance(node, ast.Dict): - return ( - allow_containers - and len(node.keys) <= 10 - and all( - ( - subnode is not None - and _is_valid_default_value_with_annotation( - subnode, allow_containers=False + case ast.Dict(keys, values): + return ( + allow_containers + and len(keys) <= 10 + and all( + ( + subnode is not None + and _is_valid_default_value_with_annotation( + subnode, allow_containers=False + ) ) + for subnode in chain(keys, values) ) - for subnode in chain(node.keys, node.values) ) - ) - # `...`, bools, None, str, bytes, - # positive ints, positive floats, positive complex numbers with no real part - if isinstance(node, ast.Constant): - return True - - # Negative ints, negative floats, negative complex numbers with no real part, - # some constants from the math module - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - numeric_types = {int, float, complex} - if ( - isinstance(node.operand, ast.Constant) - and type(node.operand.value) in numeric_types - ): + # `...`, bools, None, str, bytes, + # positive ints, positive floats, positive complex numbers with no real part + case ast.Constant(): return True - if isinstance(node.operand, ast.Attribute) and isinstance( - node.operand.value, ast.Name + + # Negative ints, negative floats, negative complex numbers with no real part + case ast.UnaryOp(op=ast.USub(), operand=ast.Constant(value=value)): + return type(value) in {int, float, complex} + + # some constants from the math module + case ast.UnaryOp( + op=ast.USub(), operand=ast.Attribute(value=ast.Name(id), attr=attr) ): - fullname = f"{node.operand.value.id}.{node.operand.attr}" - return fullname in _NEGATABLE_MATH_ATTRIBUTES_IN_DEFAULTS - return False + return f"{id}.{attr}" in _NEGATABLE_MATH_ATTRIBUTES_IN_DEFAULTS - # Complex numbers with a real part and an imaginary part... - if ( - isinstance(node, ast.BinOp) - and isinstance(node.op, (ast.Add, ast.Sub)) - and isinstance(node.right, ast.Constant) - and type(node.right.value) is complex - ): - left = node.left - # ...Where the real part is positive: - if isinstance(left, ast.Constant) and type(left.value) in {int, float}: - return True - # ...Where the real part is negative: - if ( - isinstance(left, ast.UnaryOp) - and isinstance(left.op, ast.USub) - and isinstance(left.operand, ast.Constant) - and type(left.operand.value) in {int, float} + # Complex numbers with a real part and an imaginary part... + case ast.BinOp( + left=left, op=ast.Add() | ast.Sub(), right=ast.Constant(value=complex()) ): + match left: + case ast.Constant(value): + return type(value) in {int, float} + case ast.UnaryOp(op=ast.USub(), operand=ast.Constant(value)): + return type(value) in {int, float} + case _: + return False + + # Attribute access like math.inf or enums + case ast.Attribute(): return True - return False - - # Attribute access like math.inf or enums - if isinstance(node, ast.Attribute): - return True - # Special cases - if isinstance(node, ast.Name): - return node.id in _ALLOWED_SIMPLE_NAMES_IN_DEFAULTS + # Special cases + case ast.Name(id): + return id in _ALLOWED_SIMPLE_NAMES_IN_DEFAULTS - return False + case _: + return False def _is_valid_pep_604_union_member(node: ast.expr) -> bool: @@ -767,11 +757,13 @@ def _is_valid_pep_604_union(node: ast.expr) -> TypeGuard[ast.BinOp]: def _is_valid_default_value_without_annotation(node: ast.expr) -> bool: """Is `node` a valid default for an assignment without an annotation?""" - return ( - isinstance(node, (ast.Call, ast.Name, ast.Attribute, ast.Subscript)) - or (isinstance(node, ast.Constant) and node.value in {None, ...}) - or _is_valid_pep_604_union(node) - ) + match node: + case ast.Call() | ast.Name() | ast.Attribute() | ast.Subscript(): + return True + case ast.Constant(value): + return value in {None, ...} + case _: + return _is_valid_pep_604_union(node) def _check_import_or_attribute( @@ -1104,22 +1096,20 @@ def visit_Call(self, node: ast.Call) -> None: self.visit(arg) def visit_Constant(self, node: ast.Constant) -> None: - if isinstance(node.value, str) and not self.string_literals_allowed.active: - self.error(node, errors.Y020) - elif ( - isinstance(node.value, (str, bytes)) - and not self.long_strings_allowed.active - ): - if len(node.value) > 50: - self.error(node, errors.Y053) - elif isinstance(node.value, (int, float, complex)): - if len(str(node.value)) > 10: - # The maximum character limit is arbitrary, but here's what it's based on: - # Hex representation of 32-bit integers tend to be 10 chars. - # So is the decimal representation - # of the maximum positive signed 32-bit integer. - # 0xFFFFFFFF --> 4294967295 - self.error(node, errors.Y054) + match node.value: + case str() if not self.string_literals_allowed.active: + self.error(node, errors.Y020) + case str() | bytes() if not self.long_strings_allowed.active: + if len(node.value) > 50: + self.error(node, errors.Y053) + case int() | float() | complex(): + if len(str(node.value)) > 10: + # The maximum character limit is arbitrary, but here's what it's based on: + # Hex representation of 32-bit integers tend to be 10 chars. + # So is the decimal representation + # of the maximum positive signed 32-bit integer. + # 0xFFFFFFFF --> 4294967295 + self.error(node, errors.Y054) def visit_Expr(self, node: ast.Expr) -> None: if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str): @@ -1456,20 +1446,16 @@ def _check_if_expression(self, node: ast.expr) -> None: # mypy doesn't support chained comparisons self.error(node, errors.Y002) return - if isinstance(node.left, ast.Subscript): - self._check_subscript_version_check(node) - elif isinstance(node.left, ast.Attribute): - if _is_name(node.left.value, "sys"): - if node.left.attr == "platform": - self._check_platform_check(node) - elif node.left.attr == "version_info": - self._check_version_check(node) - else: - self.error(node, errors.Y002) - else: + + match node.left: + case ast.Subscript(): + self._check_subscript_version_check(node) + case ast.Attribute(value=ast.Name("sys"), attr="platform"): + self._check_platform_check(node) + case ast.Attribute(value=ast.Name("sys"), attr="version_info"): + self._check_version_check(node) + case _: self.error(node, errors.Y002) - else: - self.error(node, errors.Y002) def _check_for_Y066_violations(self, node: ast.If) -> None: def is_version_info(attr: ast.expr) -> bool: @@ -1506,32 +1492,21 @@ def _check_subscript_version_check(self, node: ast.Compare) -> None: can_have_strict_equals: int | None = None version_info = node.left if isinstance(version_info, ast.Subscript): - slc = version_info.slice - if isinstance(slc, ast.Constant): - # anything other than the integer 0 doesn't make much sense - if type(slc.value) is int and slc.value == 0: - must_be_single = True - else: + match version_info.slice: + case ast.Constant(value): + # anything other than the integer 0 doesn't make much sense + if type(value) is int and value == 0: + must_be_single = True + else: + self.error(node, errors.Y003) + return + case ast.Slice( + lower=None, upper=ast.Constant(value), step=None + ) if type(value) is int and value in {1, 2}: + can_have_strict_equals = value + case _: self.error(node, errors.Y003) return - elif isinstance(slc, ast.Slice): - if slc.lower is not None or slc.step is not None: - self.error(node, errors.Y003) - return - elif ( - # allow only [:1] and [:2] - isinstance(slc.upper, ast.Constant) - and type(slc.upper.value) is int - and slc.upper.value in {1, 2} - ): - can_have_strict_equals = slc.upper.value - else: - self.error(node, errors.Y003) - return - else: - # extended slicing - self.error(node, errors.Y003) - return self._check_version_check( node, must_be_single=must_be_single, @@ -1564,33 +1539,24 @@ def _check_version_check( # mypy only supports major and minor version checks self.error(node, errors.Y004) - cmpop = node.ops[0] - if isinstance(cmpop, (ast.Lt, ast.GtE)): - pass - elif isinstance(cmpop, (ast.Eq, ast.NotEq)): - if can_have_strict_equals is not None: + match node.ops[0]: + case ast.Lt() | ast.GtE(): + pass + case ast.Eq() | ast.NotEq() if can_have_strict_equals is not None: if len(comparator.elts) != can_have_strict_equals: self.error(node, errors.Y005.format(n=can_have_strict_equals)) - else: + case _: self.error(node, errors.Y006) - else: - self.error(node, errors.Y006) def _check_platform_check(self, node: ast.Compare) -> None: - cmpop = node.ops[0] - # "in" might also make sense but we don't currently have one - if not isinstance(cmpop, (ast.Eq, ast.NotEq)): - self.error(node, errors.Y007) - return - - comparator = node.comparators[0] - if isinstance(comparator, ast.Constant) and type(comparator.value) is str: - # other values are possible but we don't need them right now - # this protects against typos - if comparator.value not in {"linux", "win32", "cygwin", "darwin"}: - self.error(node, errors.Y008.format(platform=comparator.value)) - else: - self.error(node, errors.Y007) + match node: + case ast.Compare( + comparators=[ast.Constant(value), *_], ops=[ast.Eq() | ast.NotEq(), *_] + ) if (type(value) is str): + if value not in {"linux", "win32", "cygwin", "darwin"}: + self.error(node, errors.Y008.format(platform=value)) + case _: + self.error(node, errors.Y007) def _check_class_bases(self, bases: list[ast.expr]) -> None: Y040_encountered = False @@ -1636,29 +1602,23 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: def check_class_pass_and_ellipsis(self, node: ast.ClassDef) -> None: # empty class body should contain "..." not "pass" - if len(node.body) == 1: - statement = node.body[0] - if ( - isinstance(statement, ast.Expr) - and isinstance(statement.value, ast.Constant) - and statement.value.value is ... - ): + match node.body: + case [ast.Expr(value=ast.Constant(value=types.EllipsisType()))]: return - elif isinstance(statement, ast.Pass): - self.error(statement, errors.Y009) + case [ast.Pass()]: + self.error(node.body[0], errors.Y009) return + case _: + pass for statement in node.body: - # "pass" should not used in class body - if isinstance(statement, ast.Pass): - self.error(statement, errors.Y012) - # "..." should not be used in non-empty class body - elif ( - isinstance(statement, ast.Expr) - and isinstance(statement.value, ast.Constant) - and statement.value.value is ... - ): - self.error(statement, errors.Y013) + match statement: + case ast.Pass(): + # "pass" should not used in class body + self.error(statement, errors.Y012) + case ast.Expr(value=ast.Constant(value)) if value is ...: + # "..." should not be used in non-empty class body + self.error(statement, errors.Y013) def _check_exit_method( # noqa: C901 self, node: ast.FunctionDef | ast.AsyncFunctionDef, method_name: str @@ -2074,21 +2034,17 @@ def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: if node.name != "__getattr__" and node.returns and _is_Incomplete(node.returns): self.error(node.returns, errors.Y065.format(what="return type")) - body = node.body - if len(body) > 1: - self.error(body[1], errors.Y048) - elif body: - statement = body[0] + match node.body: + case [_, stm2t2, *_]: + self.error(stm2t2, errors.Y048) # normally, should just be "..." - if isinstance(statement, ast.Pass): - self.error(statement, errors.Y009) + case [ast.Pass()]: + self.error(node.body[0], errors.Y009) # ... is fine. Docstrings are not but we produce # tailored error message for them elsewhere. - elif not ( - isinstance(statement, ast.Expr) - and isinstance(statement.value, ast.Constant) - and isinstance(statement.value.value, (str, type(...))) - ): + case [ast.Expr(value=ast.Constant(value=str() | types.EllipsisType()))]: + pass + case [statement]: self.error(statement, errors.Y010) self._check_pep570_syntax_used_where_applicable(node)