11#!/usr/bin/env python
22
3- import ast
4- from collections .abc import Generator
5- from pathlib import Path
3+ """Fix async client docstrings to match their sync counterparts."""
64
7- from griffe import Module , load
5+ from __future__ import annotations
86
9- from ._utils import SKIPPED_METHODS , sync_to_async_docstring
7+ import ast
8+ from pathlib import Path
9+ from typing import TYPE_CHECKING
1010
11- # Load the apify_client package
12- src_path = Path (__file__ ).parent .resolve () / '../src'
13- package = load ('apify_client' , search_paths = [str (src_path )])
11+ from ._utils import SKIPPED_METHODS , load_package , sync_to_async_docstring , walk_modules
1412
13+ if TYPE_CHECKING :
14+ from griffe import Class , Module
1515
16- def walk_modules (module : Module ) -> Generator [Module ]:
17- """Recursively yield all modules in the package."""
18- yield module
19- for submodule in module .modules .values ():
20- yield from walk_modules (submodule )
16+ Replacement = tuple [str , str , str , bool ]
17+ EditOp = tuple [str , int , int | None , str ]
2118
2219
2320def format_docstring (content : str , indent : str ) -> str :
@@ -71,57 +68,37 @@ def find_method_body_start(tree: ast.AST, class_name: str, method_name: str) ->
7168 return None
7269
7370
74- # Go through every module in the package
75- if not isinstance (package , Module ):
76- raise TypeError ('Expected griffe to load a Module' )
77- for module in walk_modules (package ):
78- replacements = []
71+ def collect_replacements (sync_class : Class , async_class : Class ) -> list [Replacement ]:
72+ """Collect docstring replacements needed for an async class."""
73+ replacements : list [Replacement ] = []
7974
80- for async_class in module . classes .values ():
81- if not async_class . name . endswith ( 'ClientAsync' ):
75+ for async_method in async_class . functions .values ():
76+ if any ( str ( d . value ) == 'ignore_docs' for d in async_method . decorators ):
8277 continue
8378
84- # Find the corresponding sync class (same name, but without Async)
85- sync_class = module .classes .get (async_class .name .replace ('ClientAsync' , 'Client' ))
86- if not sync_class :
79+ if async_method .name in SKIPPED_METHODS :
8780 continue
8881
89- # Go through all methods in the async class
90- for async_method in async_class .functions .values ():
91- # Skip methods with @ignore_docs decorator
92- if any (str (d .value ) == 'ignore_docs' for d in async_method .decorators ):
93- continue
82+ sync_method = sync_class .functions .get (async_method .name )
83+ if not sync_method or not sync_method .docstring :
84+ continue
9485
95- # Skip methods whose docstrings are intentionally different
96- if async_method .name in SKIPPED_METHODS :
97- continue
86+ correct_docstring = sync_to_async_docstring (sync_method .docstring .value )
9887
99- # Find corresponding sync method in the sync class
100- sync_method = sync_class .functions .get (async_method .name )
101- if not sync_method or not sync_method .docstring :
102- continue
88+ if not async_method .docstring :
89+ print (f' Adding missing docstring for "{ async_class .name } .{ async_method .name } "' )
90+ replacements .append ((async_class .name , async_method .name , correct_docstring , False ))
91+ elif async_method .docstring .value != correct_docstring :
92+ print (f' Updating docstring for "{ async_class .name } .{ async_method .name } "' )
93+ replacements .append ((async_class .name , async_method .name , correct_docstring , True ))
10394
104- correct_docstring = sync_to_async_docstring ( sync_method . docstring . value )
95+ return replacements
10596
106- if not async_method .docstring :
107- print (f'Fixing missing docstring for "{ async_class .name } .{ async_method .name } "...' )
108- replacements .append ((async_class .name , async_method .name , correct_docstring , False ))
109- elif async_method .docstring .value != correct_docstring :
110- replacements .append ((async_class .name , async_method .name , correct_docstring , True ))
11197
112- if not replacements :
113- continue
98+ def build_edit_ops (tree : ast .AST , replacements : list [Replacement ]) -> list [EditOp ]:
99+ """Build a list of edit operations from the collected replacements."""
100+ ops : list [EditOp ] = []
114101
115- # Read the source file and parse with ast for precise locations
116- filepath = module .filepath
117- if not isinstance (filepath , Path ):
118- continue
119- source = filepath .read_text (encoding = 'utf-8' )
120- source_lines = source .splitlines (keepends = True )
121- tree = ast .parse (source )
122-
123- # Collect replacement operations with line numbers
124- ops = []
125102 for class_name , method_name , correct_docstring , has_existing in replacements :
126103 if has_existing :
127104 result = find_docstring_range (tree , class_name , method_name )
@@ -140,7 +117,11 @@ def find_method_body_start(tree: ast.AST, class_name: str, method_name: str) ->
140117 formatted = format_docstring (correct_docstring , indent )
141118 ops .append (('insert' , insert_line , None , formatted ))
142119
143- # Sort by start line descending (process bottom-up to preserve line numbers)
120+ return ops
121+
122+
123+ def apply_edit_ops (source_lines : list [str ], ops : list [EditOp ]) -> list [str ]:
124+ """Apply edit operations to source lines (bottom-up to preserve line numbers)."""
144125 ops .sort (key = lambda x : x [1 ], reverse = True )
145126
146127 for op_type , start_line , end_line , formatted in ops :
@@ -150,5 +131,51 @@ def find_method_body_start(tree: ast.AST, class_name: str, method_name: str) ->
150131 elif op_type == 'insert' :
151132 source_lines [start_line - 1 : start_line - 1 ] = formatted_lines
152133
153- # Save the updated source code back to the file
134+ return source_lines
135+
136+
137+ def fix_module (module : Module ) -> int :
138+ """Fix docstrings in a single module. Returns the number of fixes applied."""
139+ replacements : list [Replacement ] = []
140+
141+ for async_class in module .classes .values ():
142+ if not async_class .name .endswith ('ClientAsync' ):
143+ continue
144+
145+ sync_class = module .classes .get (async_class .name .replace ('ClientAsync' , 'Client' ))
146+ if not sync_class :
147+ continue
148+
149+ replacements .extend (collect_replacements (sync_class , async_class ))
150+
151+ if not replacements :
152+ return 0
153+
154+ filepath = module .filepath
155+ if not isinstance (filepath , Path ):
156+ return 0
157+
158+ source = filepath .read_text (encoding = 'utf-8' )
159+ source_lines = source .splitlines (keepends = True )
160+ tree = ast .parse (source )
161+
162+ ops = build_edit_ops (tree , replacements )
163+ source_lines = apply_edit_ops (source_lines , ops )
164+
154165 filepath .write_text ('' .join (source_lines ), encoding = 'utf-8' )
166+ return len (ops )
167+
168+
169+ def main () -> None :
170+ """Fix all async client methods with missing or mismatched docstrings."""
171+ package = load_package ()
172+ fixed_count = sum (fix_module (module ) for module in walk_modules (package ))
173+
174+ if fixed_count :
175+ print (f'\n Fixed { fixed_count } docstring(s).' )
176+ else :
177+ print ('All async docstrings are already in sync.' )
178+
179+
180+ if __name__ == '__main__' :
181+ main ()
0 commit comments