Skip to content

Commit 9d314bf

Browse files
committed
Correctly handle multi-line imports.
1 parent 60e4839 commit 9d314bf

9 files changed

Lines changed: 151 additions & 30 deletions

File tree

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ repos:
2828
args: [--allow-git]
2929
- id: check-docstring-first
3030

31+
- repo: https://github.com/domdfcoding/flake8-dunder-all
32+
rev: v0.0.3
33+
hooks:
34+
- id: ensure-dunder-all
35+
files: ^flake8_dunder_all/.*\.py$
36+
3137
- repo: https://github.com/pre-commit/pygrep-hooks
3238
rev: v1.5.1
3339
hooks:

doc-source/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
autodocsumm>=0.2.0
12
default_values>=0.0.7
23
domdf_sphinx_theme>=0.0.11
34
extras_require
@@ -11,4 +12,3 @@ sphinxcontrib-httpdomain>=1.7.0
1112
sphinxemoji>=0.1.6
1213
toctree_plus>=0.0.2
1314
git+git://github.com/domdfcoding/sphinx-autodoc-typehints.git@typevar-as-pydata
14-
git+git://github.com/Chilipp/autodocsumm@0a1b6515ba83deb70eb2c356acf956f448536c90

flake8_dunder_all/__init__.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,16 @@
3131

3232
# stdlib
3333
import ast
34-
from typing import Any, Generator, List, Tuple, Type
34+
import sys
35+
from typing import Any, Generator, List, Tuple, Type, Union
3536

3637
# 3rd party
3738
from domdf_python_tools.paths import PathPlus
3839
from domdf_python_tools.terminal_colours import Fore
3940
from domdf_python_tools.utils import stderr_writer
4041

4142
# this package
42-
from flake8_dunder_all.utils import get_docstring_lineno
43+
from flake8_dunder_all.utils import get_docstring_lineno, mark_text_ranges
4344

4445
__author__: str = "Dominic Davis-Foster"
4546
__copyright__: str = "2020 Dominic Davis-Foster"
@@ -55,16 +56,21 @@
5556
class Visitor(ast.NodeVisitor):
5657
"""
5758
AST :class:`~ast.NodeVisitor` to check a module has defined ``__all__``, and add one if it not.
59+
60+
:param use_endlineno: Flag to indicate whether the end_lineno functionality is available.
61+
This functionality is available on Python 3.8 and above, or when the tree has been passed through
62+
:func:`flake8_dunder_all.utils.mark_text_ranges``.
5863
"""
5964

60-
def __init__(self) -> None:
61-
self.found_all = False
65+
found_all: bool #: Flag to indicate a ``__all__`` variable has been found in the AST.
66+
last_import: int #: The lineno of the last top-level import
67+
members: List[str] #: List of functions and classed defined in the AST
6268

63-
# List of functions and classed defined in this module
69+
def __init__(self, use_endlineno: bool = False) -> None:
70+
self.found_all = False
6471
self.members: List[str] = []
65-
66-
# Lineno of last top-level import
6772
self.last_import = 0
73+
self.use_endlineno = use_endlineno
6874

6975
def visit_Name(self, node: ast.Name):
7076
"""
@@ -78,6 +84,16 @@ def visit_Name(self, node: ast.Name):
7884
else:
7985
self.generic_visit(node)
8086

87+
def handle_def(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]):
88+
"""
89+
Handles ``def foo(): ...``, ``async def foo(): ...`` and ``class Foo: ...``.
90+
91+
:param node: The node being visited.
92+
"""
93+
94+
if not node.name.startswith("_"):
95+
self.members.append(node.name)
96+
8197
def visit_FunctionDef(self, node: ast.FunctionDef):
8298
"""
8399
Visit ``def foo(): ...``.
@@ -86,8 +102,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
86102
"""
87103

88104
# Don't generic visit
89-
if not node.name.startswith("_"):
90-
self.members.append(node.name)
105+
self.handle_def(node)
91106

92107
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
93108
"""
@@ -97,8 +112,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
97112
"""
98113

99114
# Don't generic visit
100-
if not node.name.startswith("_"):
101-
self.members.append(node.name)
115+
self.handle_def(node)
102116

103117
def visit_ClassDef(self, node: ast.ClassDef):
104118
"""
@@ -108,8 +122,21 @@ def visit_ClassDef(self, node: ast.ClassDef):
108122
"""
109123

110124
# Don't generic visit
111-
if not node.name.startswith("_"):
112-
self.members.append(node.name)
125+
self.handle_def(node)
126+
127+
def handle_import(self, node: Union[ast.Import, ast.ImportFrom]):
128+
"""
129+
Handles ``import foo`` and ``from foo import bar``.
130+
131+
:param node: The node being visited
132+
"""
133+
134+
if self.use_endlineno:
135+
if not node.col_offset and node.end_lineno > self.last_import: # type: ignore
136+
self.last_import = node.end_lineno # type: ignore
137+
else:
138+
if not node.col_offset and node.lineno > self.last_import:
139+
self.last_import = node.lineno
113140

114141
def visit_Import(self, node: ast.Import):
115142
"""
@@ -119,8 +146,7 @@ def visit_Import(self, node: ast.Import):
119146
"""
120147

121148
# Don't generic visit
122-
if not node.col_offset and node.lineno > self.last_import:
123-
self.last_import = node.lineno
149+
self.handle_import(node)
124150

125151
def visit_ImportFrom(self, node: ast.ImportFrom):
126152
"""
@@ -130,8 +156,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
130156
"""
131157

132158
# Don't generic visit
133-
if not node.col_offset and node.lineno > self.last_import:
134-
self.last_import = node.lineno
159+
self.handle_import(node)
135160

136161

137162
class Plugin:
@@ -188,12 +213,15 @@ def check_and_add_all(filename: PathPlus, quote_type: str = '"') -> int:
188213
filename = PathPlus(filename)
189214

190215
try:
191-
tree = ast.parse(filename.read_text())
216+
source = filename.read_text()
217+
tree = ast.parse(source)
218+
if sys.version_info < (3, 8): # pragma: no cover (<py38)
219+
mark_text_ranges(tree, source)
192220
except SyntaxError:
193221
stderr_writer(Fore.RED(f"'{filename}' does not appear to be a valid Python source file."))
194222
return 4
195223

196-
visitor = Visitor()
224+
visitor = Visitor(use_endlineno=True)
197225
visitor.visit(tree)
198226

199227
if visitor.found_all:

flake8_dunder_all/__main__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
5151
5252
"""
5353
parser = argparse.ArgumentParser(
54-
description=tidy_docstring(main.__doc__), formatter_class=argparse.RawTextHelpFormatter
54+
description=tidy_docstring(main.__doc__),
55+
formatter_class=argparse.RawTextHelpFormatter,
5556
)
5657
parser.add_argument('filenames', type=str, nargs='*', help="The filename(s) to lint.", metavar="FILENAME")
5758
parser.add_argument(

flake8_dunder_all/utils.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,21 @@
3232
# Copyright © 1995-2000 Corporation for National Research Initiatives. All rights reserved.
3333
# Copyright © 1991-1995 Stichting Mathematisch Centrum. All rights reserved.
3434
#
35+
# mark_text_ranges from Thonny
36+
# https://github.com/thonny/thonny/blob/master/thonny/ast_utils.py
37+
# Copyright (c) 2020 Aivar Annamaa
38+
# MIT Licensed
3539

3640
# stdlib
3741
import ast
3842
import re
39-
from textwrap import dedent, indent
43+
from textwrap import dedent
4044
from typing import Optional, Union
4145

42-
__all__ = ["get_docstring_lineno", "tidy_docstring"]
46+
# 3rd party
47+
from asttokens.asttokens import ASTTokens
48+
49+
__all__ = ["get_docstring_lineno", "tidy_docstring", "mark_text_ranges"]
4350

4451

4552
def get_docstring_lineno(node: Union[ast.FunctionDef, ast.ClassDef, ast.Module]) -> Optional[int]:
@@ -76,3 +83,24 @@ def tidy_docstring(docstring: Optional[str]) -> str:
7683
docstring = re.sub("``([^`]*)``", r"'\1'", docstring)
7784

7885
return f"\n{docstring}"
86+
87+
88+
def mark_text_ranges(node: ast.AST, source: str):
89+
"""
90+
Node is an AST, source is corresponding source as string.
91+
Function adds recursively attributes end_lineno and end_col_offset to each node
92+
which has attributes lineno and col_offset.
93+
94+
:param node:
95+
:param source: The corresponding source code for the node.
96+
"""
97+
98+
ASTTokens(source, tree=node)
99+
100+
for child in ast.walk(node):
101+
if hasattr(child, "last_token"):
102+
child.end_lineno, child.end_col_offset = child.last_token.end # type: ignore
103+
104+
if hasattr(child, "lineno"):
105+
# Fixes problems with some nodes like binop
106+
child.lineno, child.col_offset = child.first_token.start # type: ignore

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
asttokens>=1.1
12
domdf_python_tools>=0.4.10
23
flake8>=3.7

tests/common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,25 @@ async def a_function(): ...
7878
7979
8080
81+
async def a_function(): ...
82+
'''
83+
84+
testing_source_j = '''
85+
from tests.common import (
86+
mangled_source,
87+
results,
88+
testing_source_a,
89+
testing_source_b,
90+
testing_source_c,
91+
testing_source_d,
92+
testing_source_e,
93+
testing_source_f,
94+
testing_source_g,
95+
testing_source_h,
96+
testing_source_i
97+
)
98+
99+
81100
async def a_function(): ...
82101
'''
83102

tests/test_flake8_dunder_all.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# stdlib
22
import ast
33
import re
4+
import sys
45
from typing import List
56

67
# 3rd party
@@ -9,7 +10,7 @@
910
from domdf_python_tools.terminal_colours import Fore
1011

1112
# this package
12-
from flake8_dunder_all import Visitor, check_and_add_all
13+
from flake8_dunder_all import Visitor, check_and_add_all, mark_text_ranges
1314
from tests.common import (
1415
mangled_source,
1516
results,
@@ -21,7 +22,8 @@
2122
testing_source_f,
2223
testing_source_g,
2324
testing_source_h,
24-
testing_source_i
25+
testing_source_i,
26+
testing_source_j
2527
)
2628

2729

@@ -44,6 +46,7 @@
4446
),
4547
pytest.param(testing_source_h, set(), id="from import"),
4648
pytest.param(testing_source_i, {'0:0: DALL000 Module lacks __all__.'}, id="lots of lines"),
49+
pytest.param(testing_source_j, {'0:0: DALL000 Module lacks __all__.'}, id="multiline import"),
4750
]
4851
)
4952
def test_plugin(source, expects):
@@ -73,6 +76,7 @@ def test_plugin(source, expects):
7376
pytest.param(testing_source_g, ["a_function"], False, 3, id="async function no __all__"),
7477
pytest.param(testing_source_h, [], False, 1, id="from import"),
7578
pytest.param(testing_source_i, ["a_function"], False, 3, id="lots of lines"),
79+
pytest.param(testing_source_j, ["a_function"], False, 2, id="multiline import"),
7680
]
7781
)
7882
def test_visitor(source, members, found_all, last_import):
@@ -84,6 +88,43 @@ def test_visitor(source, members, found_all, last_import):
8488
assert visitor.last_import is last_import
8589

8690

91+
@pytest.mark.parametrize(
92+
"source, members, found_all, last_import",
93+
[
94+
pytest.param('import foo', [], False, 1, id="just an import"),
95+
pytest.param('"""a docstring"""', [], False, 0, id="just a docstring"),
96+
pytest.param(testing_source_a, [], False, 3, id="import and docstring"),
97+
pytest.param(testing_source_b, ['a_function'], False, 3, id="function no __all__"),
98+
pytest.param(testing_source_c, ['Foo'], False, 3, id="class no __all__"),
99+
pytest.param(
100+
testing_source_d, ['Foo', 'a_function'], False, 3, id="function and class no __all__"
101+
),
102+
pytest.param(
103+
testing_source_e, ['Foo', 'a_function'], True, 3, id="function and class with __all__"
104+
),
105+
pytest.param(
106+
testing_source_f, ['Foo', 'a_function'],
107+
True,
108+
3,
109+
id="function and class with __all__ and extra variable"
110+
),
111+
pytest.param(testing_source_g, ["a_function"], False, 3, id="async function no __all__"),
112+
pytest.param(testing_source_h, [], False, 1, id="from import"),
113+
pytest.param(testing_source_i, ["a_function"], False, 3, id="lots of lines"),
114+
pytest.param(testing_source_j, ["a_function"], False, 14, id="multiline import"),
115+
]
116+
)
117+
def test_visitor_endlineno(source, members, found_all, last_import):
118+
visitor = Visitor(True)
119+
tree = ast.parse(source)
120+
mark_text_ranges(tree, source)
121+
visitor.visit(tree)
122+
123+
assert visitor.members == members
124+
assert visitor.found_all is found_all
125+
assert visitor.last_import is last_import
126+
127+
87128
@pytest.mark.parametrize(
88129
"source, members, ret",
89130
[
@@ -105,7 +146,6 @@ def test_visitor(source, members, found_all, last_import):
105146
]
106147
)
107148
def test_check_and_add_all(tmpdir, source, members: List[str], ret):
108-
109149
tmpfile = PathPlus(tmpdir) / "source.py"
110150
tmpfile.write_text(source)
111151

@@ -116,6 +156,7 @@ def test_check_and_add_all(tmpdir, source, members: List[str], ret):
116156
assert f"__all__ = [{members_string}]" in tmpfile.read_text()
117157

118158

159+
@pytest.mark.skipif(condition=not (sys.version_info < (3, 8)), reason="Not required after python 3.8")
119160
@pytest.mark.parametrize(
120161
"source, members, ret",
121162
[
@@ -131,7 +172,6 @@ def test_check_and_add_all(tmpdir, source, members: List[str], ret):
131172
]
132173
)
133174
def test_check_and_add_all_single_quotes(tmpdir, source, members: List[str], ret):
134-
135175
tmpfile = PathPlus(tmpdir) / "source.py"
136176
tmpfile.write_text(source)
137177

@@ -146,7 +186,6 @@ def test_check_and_add_all_single_quotes(tmpdir, source, members: List[str], ret
146186
pytest.param(mangled_source, [], id="mangled"),
147187
])
148188
def test_check_and_add_all_mangled(tmpdir, capsys, source, members):
149-
150189
tmpfile = PathPlus(tmpdir) / "source.py"
151190
tmpfile.write_text(source)
152191
assert check_and_add_all(tmpfile) == 4

tests/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import ast
33

44
# this package
5-
from flake8_dunder_all import get_docstring_lineno
6-
from flake8_dunder_all.__main__ import tidy_docstring
5+
from flake8_dunder_all.utils import get_docstring_lineno, tidy_docstring
76

87

98
def test_get_docstring_lineno():

0 commit comments

Comments
 (0)