-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathdocstring.py
More file actions
386 lines (330 loc) · 13.4 KB
/
docstring.py
File metadata and controls
386 lines (330 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
"""Utility for retrieveing the docstring of a dataclass's attributes.
@author: Fabrice Normandin
"""
from __future__ import annotations
import functools
import inspect
# from inspect import
from dataclasses import dataclass
from logging import getLogger
import docstring_parser as dp
from docstring_parser.common import Docstring
# Cached wrappers around functions that get called repeatedly with the same arguments: the
# source code and the docstring of a class are looked up again for every field queried on it.
# NOTE: `functools.lru_cache` keeps strong references to its arguments, so the classes passed
# to these wrappers stay alive until they are evicted from the cache.
dp_parse = functools.lru_cache(2048)(dp.parse)
inspect_getsource = functools.lru_cache(2048)(inspect.getsource)
inspect_getdoc = functools.lru_cache(2048)(inspect.getdoc)
logger = getLogger(__name__)
@dataclass
class AttributeDocString:
    """Container for every piece of documentation found for a single dataclass field."""

    # Text of the `#` comment line(s) directly above the field definition.
    comment_above: str = ""
    # Text of the `#` comment on the same line as the field definition.
    comment_inline: str = ""
    # Contents of the string literal placed right below the field definition.
    docstring_below: str = ""
    desc_from_cls_docstring: str = ""
    """The description of this field from the class docstring."""

    @property
    def help_string(self) -> str:
        """The text to display for this field in the "--help" output.

        The first non-empty value among `docstring_below`, `comment_above` and
        `comment_inline` wins; when all of them are empty, `desc_from_cls_docstring`
        is returned as-is.
        """
        for candidate in (self.docstring_below, self.comment_above, self.comment_inline):
            if candidate:
                return candidate
        return self.desc_from_cls_docstring
def get_attribute_docstring(
    dataclass: type, field_name: str, accumulate_from_bases: bool = True
) -> AttributeDocString:
    """Returns the docstrings of a dataclass field.

    NOTE: a docstring can either be:
    - An inline comment, starting with <#>
    - A Comment on the preceding line, starting with <#>
    - A docstring on the following line, starting with either <\"\"\"> or <'''>
    - The description of a field in the classes's docstring.

    Arguments:
        dataclass: a dataclass. Its MRO (minus `object`) is searched, most-derived class first.
        field_name: the name of the field.
        accumulate_from_bases: Whether to accumulate the docstring components by looking through the
            base classes. When set to `False`, whenever one of the classes has a definition for the
            field, it is directly returned. Otherwise, each component that is still empty is filled
            in from the next class in the MRO that provides a non-empty value for it.

    Returns:
        AttributeDocString -- a new object holding the string descriptions of the field (an empty
        one if no class in the MRO defines the field). It is never shared with the internal
        cache, so callers may mutate it freely.
    """
    import dataclasses  # for `dataclasses.replace`

    created_docstring: AttributeDocString | None = None

    mro = inspect.getmro(dataclass)
    assert mro[0] is dataclass
    assert mro[-1] is object
    mro = mro[:-1]  # `object` never defines dataclass fields.
    for base_class in mro:
        attribute_docstring = _get_attribute_docstring(base_class, field_name)
        if attribute_docstring is None:
            continue
        if created_docstring is None:
            # `_get_attribute_docstring` is lru-cached, so the object it returns is shared
            # between calls. Copy it before returning or mutating it; otherwise values
            # accumulated from base classes would leak into the cached entry.
            created_docstring = dataclasses.replace(attribute_docstring)
            if not accumulate_from_bases:
                # We found a definition for that field in that class, so return it directly.
                return created_docstring
        else:
            # Fill in any component that is still empty with the value from this base class.
            created_docstring.comment_above = (
                created_docstring.comment_above or attribute_docstring.comment_above
            )
            created_docstring.comment_inline = (
                created_docstring.comment_inline or attribute_docstring.comment_inline
            )
            created_docstring.docstring_below = (
                created_docstring.docstring_below or attribute_docstring.docstring_below
            )
            created_docstring.desc_from_cls_docstring = (
                created_docstring.desc_from_cls_docstring
                or attribute_docstring.desc_from_cls_docstring
            )
    if created_docstring is None:
        logger.debug(
            RuntimeWarning(
                f"Couldn't find the definition for field '{field_name}' within the dataclass "
                f"{dataclass} or any of its base classes {','.join(t.__name__ for t in mro[1:])}."
            )
        )
        return AttributeDocString()
    return created_docstring
@functools.lru_cache(2048)
def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocString | None:
    """Gets the AttributeDocString of the given field in the given dataclass.

    Doesn't inspect base classes: only the source code of `dataclass` itself is searched.

    NOTE: The result is cached by `functools.lru_cache`, so repeated calls with the same
    arguments return the *same* object. Callers must copy it before mutating it.

    Arguments:
        dataclass: the class whose source code is searched for the field definition.
        field_name: the name of the field.

    Returns:
        The comment above, inline comment, docstring below and class-docstring description of
        the first source line that defines `field_name`; or None if the source code of
        `dataclass` can't be retrieved or no line defines that field.
    """
    try:
        source = inspect_getsource(dataclass)
    except (TypeError, OSError) as e:
        logger.debug(
            UserWarning(
                f"Couldn't retrieve the source code of class {dataclass} "
                f"(in order to retrieve the docstring of field {field_name}): {e}"
            )
        )
        return None

    # Look for a description of the field in the class docstring (parsed by docstring_parser).
    # NOTE(review): `inspect.getdoc` falls back to a base class's docstring when `dataclass`
    # has none of its own, so this description may come from a base class.
    desc_from_cls_docstring = ""
    cls_docstring = inspect_getdoc(dataclass)
    if cls_docstring:
        docstring: Docstring = dp_parse(cls_docstring)
        for param in docstring.params:
            if param.arg_name == field_name:
                # If the field is documented more than once, the last entry wins.
                desc_from_cls_docstring = param.description or ""

    # Remove the class docstring from the source, so that lines inside it that look like field
    # definitions (e.g. "a: the a field") aren't mistaken for real ones. This is a bit crude:
    # it only works when `__doc__` appears verbatim in the source.
    # NOTE(review): Python 3.13+ strips common leading whitespace from docstrings at compile
    # time, so an indented multi-line `__doc__` may no longer match the source text; confirm.
    if dataclass.__doc__ and dataclass.__doc__ in source:
        source = source.replace(dataclass.__doc__, "\n", 1)

    code_lines: list[str] = source.splitlines()
    # The first line (class header or decorator) can never match, so scanning from the top is
    # safe. `_line_contains_definition_for` already checks that the line is a field definition.
    for i, line in enumerate(code_lines):
        if not _line_contains_definition_for(line, field_name):
            continue
        # We found the line with the definition of this field.
        return AttributeDocString(
            _get_comment_ending_at_line(code_lines, i - 1),
            _get_inline_comment_at_line(code_lines, i),
            _get_docstring_starting_at_line(code_lines, i + 1),
            desc_from_cls_docstring=desc_from_cls_docstring,
        )
    return None
def _contains_field_definition(line: str) -> bool:
"""Returns whether or not a line contains a an dataclass field definition.
Arguments:
line_str {str} -- the line content
Returns:
bool -- True if there is an attribute definition in the line.
>>> _contains_field_definition("a: int = 0")
True
>>> _contains_field_definition("a: int")
True
>>> _contains_field_definition("a: int # comment")
True
>>> _contains_field_definition("a: int = 0 # comment")
True
>>> _contains_field_definition("class FooBaz(Foo, Baz):")
False
>>> _contains_field_definition("a = 4")
False
>>> _contains_field_definition("fooooooooobar.append(123)")
False
>>> _contains_field_definition("{a: int}")
False
>>> _contains_field_definition(" foobaz: int = 123 #: The foobaz property")
True
>>> _contains_field_definition("a #:= 3")
False
"""
# Get rid of any comments first.
line, _, _ = line.partition("#")
if ":" not in line:
return False
if "=" in line:
attribute_and_type, _, _ = line.partition("=")
else:
attribute_and_type = line
field_name, _, type = attribute_and_type.partition(":")
field_name = field_name.strip()
if ":" in type:
# weird annotation or dictionary?
return False
if not field_name:
# Empty attribute name?
return False
return field_name.isidentifier()
def _line_contains_definition_for(line: str, field_name: str) -> bool:
    """Returns whether `line` defines the field called `field_name`.

    The line must look like a field definition (see `_contains_field_definition`), and the
    text before its first ":" must be exactly `field_name`.
    """
    candidate = line.strip()
    if not _contains_field_definition(candidate):
        return False
    name = candidate.split(":", 1)[0].strip()
    return name.isidentifier() and name == field_name
def _is_empty(line_str: str) -> bool:
return line_str.strip() == ""
def _is_comment(line_str: str) -> bool:
return line_str.strip().startswith("#")
def _get_comment_at_line(code_lines: list[str], line: int) -> str:
    """Returns the comment text found on a line that is not a field definition.

    Arguments:
        code_lines {list[str]} -- the lines of source code.
        line {int} -- the index of the line in code_lines.

    Returns:
        str -- everything after the first "#" on that line, stripped of surrounding
        whitespace; an empty string if the line has no "#".
    """
    text = code_lines[line]
    assert not _contains_field_definition(text)
    _, marker, comment = text.partition("#")
    if not marker:
        return ""
    return comment.strip()
def _get_inline_comment_at_line(code_lines: list[str], line: int) -> str:
    """Gets the inline comment at line `line`.

    A "#" inside a string literal (e.g. `name: str = "#general"  # The channel`) does not
    start the comment: the line is tokenized, and only a real comment token is used. If the
    line can't be tokenized on its own (e.g. it opens an unterminated multi-line string),
    this falls back to taking everything after the first "#".

    Arguments:
        code_lines {list[str]} -- the lines of source code.
        line {int} -- the index of the line in code_lines. That line must contain a field
            definition.

    Returns:
        str -- the inline comment at the given line (without the leading "#" and surrounding
        whitespace), else an empty string.
    """
    import io
    import tokenize

    assert 0 <= line < len(code_lines)
    assert _contains_field_definition(code_lines[line])
    line_str = code_lines[line]
    if "#" not in line_str:
        return ""
    try:
        tokens = tokenize.generate_tokens(io.StringIO(line_str.strip() + "\n").readline)
        for tok in tokens:
            if tok.type == tokenize.COMMENT:
                return tok.string[1:].strip()
    except (tokenize.TokenError, SyntaxError):
        # The line is not a complete statement on its own: use the naive split below.
        pass
    else:
        # The line tokenized cleanly without a comment: every "#" is inside a string literal.
        return ""
    return line_str.split("#", maxsplit=1)[1].strip()
def _get_comment_ending_at_line(code_lines: list[str], line: int) -> str:
    """Collects the comment text on the lines ending at index `line` (inclusive).

    Starting from `line`, walk upwards until reaching index 0, a field definition, or a line
    containing a triple quote (e.g. the end of a docstring). That stopping line is excluded, and
    index 0 is never included. Every non-blank line in between contributes the text after its
    first "#" ("" when it has none). The pieces are joined with newlines, and the result is
    stripped.
    """
    first = line
    while first > 0:
        text = code_lines[first]
        if _contains_field_definition(text) or '"""' in text or "'''" in text:
            break
        first -= 1

    collected: list[str] = []
    for index in range(first + 1, line + 1):
        if _is_empty(code_lines[index]):
            continue
        assert not _contains_field_definition(code_lines[index])
        collected.append(_get_comment_at_line(code_lines, index))
    return "\n".join(collected).strip()
def _get_docstring_starting_at_line(code_lines: list[str], line: int) -> str:
    """Returns the contents of the string-literal "docstring" that follows a field definition.

    Starting at index `line`, blank lines are skipped. If the first non-blank line is a field
    definition or a comment, or contains no triple quote, there is no docstring and "" is
    returned. Otherwise whichever triple quote (`'''` or `\"\"\"`) appears first on that line is
    used as the delimiter, and lines are collected until the closing delimiter is found (or the
    code ends).

    Arguments:
        code_lines {list[str]} -- the lines of source code.
        line {int} -- the index of the first line to look at (normally the line right after
            the field definition).

    Returns:
        str -- the docstring's lines, each stripped of surrounding whitespace and joined with
        newlines (empty first/last lines are kept), or "" if there is no docstring.
    """
    i = line
    token: str | None = None
    triple_single = "'''"
    triple_double = '"""'
    # if we are looking further down than the end of the code, there is no
    # docstring.
    if line >= len(code_lines):
        return ""
    # the list of lines making up the docstring.
    docstring_contents: list[str] = []
    while i < len(code_lines):
        line_str = code_lines[i]
        # we haven't identified the starting line yet.
        if token is None:
            if _is_empty(line_str):
                i += 1
                continue
            elif _contains_field_definition(line_str) or _is_comment(line_str):
                # we haven't reached the start of a docstring yet (since token
                # is None), and we reached a line with an attribute definition,
                # or a comment, hence the docstring is empty.
                return ""
            elif triple_single in line_str and triple_double in line_str:
                # Both kinds of triple quote on the opening line: the one that comes first
                # opens the docstring. This handles something unusual like:
                # @dataclass
                # class Bob:
                #     a: int
                #     """ hello '''
                #     bob
                #     ''' bye
                #     """
                triple_single_index = line_str.index(triple_single)
                triple_double_index = line_str.index(triple_double)
                if triple_single_index < triple_double_index:
                    token = triple_single
                else:
                    token = triple_double
            elif triple_double in line_str:
                token = triple_double
            elif triple_single in line_str:
                token = triple_single
            else:
                # Some other statement follows the field: no docstring.
                logger.debug(f"Warning: Unable to parse attribute docstring: {line_str}")
                return ""
            # get the string portion of the line (after a token or possibly
            # between two tokens).
            parts = line_str.split(token, maxsplit=2)
            if len(parts) == 3:
                # The docstring opens and closes on the same line, e.g.:
                # @dataclass
                # class Bob:
                #     a: int
                #     """ hello """
                between_tokens = parts[1].strip()
                docstring_contents.append(between_tokens)
                break
            elif len(parts) == 2:
                # Multi-line docstring: keep the text after the opening token.
                after_token = parts[1].strip()
                docstring_contents.append(after_token)
        else:
            # We are inside the docstring, looking for the closing token.
            if token in line_str:
                # End of the docstring: keep the text before the closing token.
                before = line_str.split(token, maxsplit=1)[0]
                docstring_contents.append(before.strip())
                break
            else:
                # intermediate line without the token.
                docstring_contents.append(line_str.strip())
        i += 1
    return "\n".join(docstring_contents)