Skip to content

Commit 218e78a

Browse files
committed
Correctly handle @overload decorated functions.
1 parent cc2e185 commit 218e78a

3 files changed

Lines changed: 50 additions & 9 deletions

File tree

flake8_dunder_all/__init__.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
# stdlib
3333
import ast
3434
import sys
35-
from typing import Any, Generator, List, Tuple, Type, Union
35+
from typing import Any, Generator, List, Set, Tuple, Type, Union
3636

3737
# 3rd party
3838
from consolekit.terminal_colours import Fore
@@ -64,11 +64,12 @@ class Visitor(ast.NodeVisitor):
6464

6565
found_all: bool #: Flag to indicate a ``__all__`` variable has been found in the AST.
6666
last_import: int #: The lineno of the last top-level import
67-
members: List[str] #: List of functions and classed defined in the AST
67+
members: Set[str] #: List of functions and classed defined in the AST
68+
use_endlineno: bool
6869

6970
def __init__(self, use_endlineno: bool = False) -> None:
7071
self.found_all = False
71-
self.members: List[str] = []
72+
self.members = set()
7273
self.last_import = 0
7374
self.use_endlineno = use_endlineno
7475

@@ -91,8 +92,30 @@ def handle_def(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Clas
9192
:param node: The node being visited.
9293
"""
9394

94-
if not node.name.startswith('_'):
95-
self.members.append(node.name)
95+
decorators = []
96+
97+
for deco in node.decorator_list:
98+
if isinstance(deco, ast.Name):
99+
decorators.append(deco.id)
100+
elif isinstance(deco, ast.Attribute):
101+
parts = [deco.attr]
102+
103+
# last_part = deco.value
104+
#
105+
# while True:
106+
# if isinstance(last_part, ast.Attribute):
107+
# parts.append(last_part.attr)
108+
# last_part = last_part.value
109+
# elif isinstance(last_part, ast.Name):
110+
# parts.append(last_part.id)
111+
# break
112+
# else:
113+
# break
114+
115+
decorators.append('.'.join(reversed(parts)))
116+
117+
if not node.name.startswith('_') and "overload" not in decorators:
118+
self.members.add(node.name)
96119

97120
def visit_FunctionDef(self, node: ast.FunctionDef):
98121
"""
@@ -236,7 +259,7 @@ def check_and_add_all(filename: PathPlus, quote_type: str = '"') -> int:
236259
if not visitor.members:
237260
return 0
238261

239-
members = repr(visitor.members).replace(bad_quote, quote_type)
262+
members = repr(sorted(visitor.members)).replace(bad_quote, quote_type)
240263

241264
lines = filename.read_text().split('\n')
242265

tests/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@ async def a_function(): ...
100100
async def a_function(): ...
101101
"""
102102

103+
testing_source_k = '''
104+
"""a docstring"""
105+
106+
@overload
107+
def a_function(): ...
108+
'''
109+
110+
testing_source_l = '''
111+
"""a docstring"""
112+
113+
@typing.overload
114+
def a_function(): ...
115+
'''
116+
103117
mangled_source = '''
104118
"""a docstring
105119
import foo

tests/test_flake8_dunder_all.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
testing_source_g,
2525
testing_source_h,
2626
testing_source_i,
27-
testing_source_j
27+
testing_source_j,
28+
testing_source_k,
29+
testing_source_l
2830
)
2931

3032

@@ -84,7 +86,7 @@ def test_visitor(source, members, found_all, last_import):
8486
visitor = Visitor()
8587
visitor.visit(ast.parse(source))
8688

87-
assert visitor.members == members
89+
assert sorted(visitor.members) == members
8890
assert visitor.found_all is found_all
8991
assert visitor.last_import is last_import
9092

@@ -121,7 +123,7 @@ def test_visitor_endlineno(source, members, found_all, last_import):
121123
mark_text_ranges(tree, source)
122124
visitor.visit(tree)
123125

124-
assert visitor.members == members
126+
assert sorted(visitor.members) == members
125127
assert visitor.found_all is found_all
126128
assert visitor.last_import is last_import
127129

@@ -144,6 +146,8 @@ def test_visitor_endlineno(source, members, found_all, last_import):
144146
pytest.param(testing_source_g, ["a_function"], 1, id="async function no __all__"),
145147
pytest.param(testing_source_h, [], 0, id="from import"),
146148
pytest.param(testing_source_i, [], 1, id="lots of lines"),
149+
pytest.param(testing_source_k, [], 0, id="overload"),
150+
pytest.param(testing_source_l, [], 0, id="typing.overload"),
147151
]
148152
)
149153
def test_check_and_add_all(tmpdir, source, members: List[str], ret):

0 commit comments

Comments
 (0)