@@ -10,14 +10,20 @@ class alongside the canonical `ErrorType(StrEnum)`. This script removes the dupl
1010 rewires references to use `ErrorType`.
1111- Missing @docs_group decorator: Adds `@docs_group('Models')` to all model classes for API
1212 reference documentation grouping, along with the required import.
13+
14+ Also generates `_generated_errors.py` — one `ApifyApiError` subclass per `ErrorType` enum member
15+ plus a dispatch map used by `ApifyApiError.__new__` to return the specific subclass.
1316"""
1417
1518from __future__ import annotations
1619
20+ import ast
21+ import builtins
1722import re
1823from pathlib import Path
1924
2025MODELS_PATH = Path (__file__ ).resolve ().parent .parent / 'src' / 'apify_client' / '_models.py'
26+ GENERATED_ERRORS_PATH = Path (__file__ ).resolve ().parent .parent / 'src' / 'apify_client' / '_generated_errors.py'
2127DOCS_GROUP_DECORATOR = "@docs_group('Models')"
2228
2329# Map of camelCase discriminator values to their snake_case equivalents.
@@ -76,6 +82,130 @@ def add_docs_group_decorators(content: str) -> str:
7682 return '\n ' .join (result )
7783
7884
85+ def extract_error_type_members (content : str ) -> list [tuple [str , str ]]:
86+ """Parse `_models.py` and return `(member_name, member_value)` tuples for the `ErrorType` enum.
87+
88+ Uses AST parsing for robustness against formatting differences. Returns an empty list if the
89+ `ErrorType` class is not found.
90+ """
91+ tree = ast .parse (content )
92+ for node in ast .walk (tree ):
93+ if isinstance (node , ast .ClassDef ) and node .name == 'ErrorType' :
94+ return [
95+ (stmt .targets [0 ].id , stmt .value .value )
96+ for stmt in node .body
97+ if (
98+ isinstance (stmt , ast .Assign )
99+ and len (stmt .targets ) == 1
100+ and isinstance (stmt .targets [0 ], ast .Name )
101+ and isinstance (stmt .value , ast .Constant )
102+ and isinstance (stmt .value .value , str )
103+ )
104+ ]
105+ return []
106+
107+
108+ def _pascal_case (name : str ) -> str :
109+ """Convert `SCREAMING_SNAKE_CASE` to `PascalCase`, preserving all-caps parts that contain digits.
110+
111+ Parts like `3D` or `X402` are left as-is so the result reads naturally (e.g.
112+ `FIELD_3D_SECURE` → `Field3DSecure` rather than `Field3dSecure`).
113+ """
114+ return '' .join (part if any (c .isdigit () for c in part ) else part .capitalize () for part in name .split ('_' ))
115+
116+
117+ def derive_exception_class_names (members : list [tuple [str , str ]]) -> list [tuple [str , str , str ]]:
118+ """Derive unique Exception class names for each `ErrorType` enum member.
119+
120+ Strategy: strip a trailing `_ERROR` from the enum name and PascalCase the result, then append
121+ `Error`. If that collides with a previously derived name, always append `Error` to the full
122+ enum name — so `SCHEMA_VALIDATION` → `SchemaValidationError` (first wins) and
123+ `SCHEMA_VALIDATION_ERROR` falls back to `SchemaValidationErrorError`.
124+
125+ Returns a list of `(member_name, member_value, class_name)` tuples.
126+ """
127+ taken : set [str ] = set ()
128+ builtin_names = set (dir (builtins ))
129+ result : list [tuple [str , str , str ]] = []
130+ for member_name , member_value in members :
131+ stripped = member_name .removesuffix ('_ERROR' )
132+ candidate = _pascal_case (stripped ) + 'Error'
133+ if candidate in taken :
134+ candidate = _pascal_case (member_name ) + 'Error'
135+ # Avoid shadowing builtins like `NotImplementedError` or `TimeoutError`.
136+ if candidate in builtin_names :
137+ candidate = 'Api' + candidate
138+ if candidate in taken :
139+ raise RuntimeError (
140+ f'Cannot derive a unique Exception class name for ErrorType.{ member_name } '
141+ f'(value={ member_value !r} ); collides with an existing class. '
142+ 'Extend derive_exception_class_names to handle this case.'
143+ )
144+ taken .add (candidate )
145+ result .append ((member_name , member_value , candidate ))
146+ return result
147+
148+
149+ def render_generated_errors_module (classes : list [tuple [str , str , str ]]) -> str :
150+ """Render the full `_generated_errors.py` source from the derived class list."""
151+ lines : list [str ] = [
152+ '# generated by scripts/postprocess_generated_models.py -- do not edit manually' ,
153+ '"""Auto-generated Exception subclasses, one per `ErrorType` enum member.' ,
154+ '' ,
155+ 'Each subclass inherits from `ApifyApiError` so existing `except ApifyApiError` handlers' ,
156+ 'keep working. `ApifyApiError.__new__` uses `API_ERROR_CLASS_BY_TYPE` to dispatch to the' ,
157+ 'specific subclass based on the `type` field of the API error response.' ,
158+ '"""' ,
159+ '' ,
160+ 'from __future__ import annotations' ,
161+ '' ,
162+ 'from apify_client._docs import docs_group' ,
163+ 'from apify_client.errors import ApifyApiError' ,
164+ '' ,
165+ ]
166+
167+ for _member_name , member_value , class_name in classes :
168+ lines .extend (
169+ [
170+ '' ,
171+ "@docs_group('Errors')" ,
172+ f'class { class_name } (ApifyApiError):' ,
173+ f' """Raised when the Apify API returns a `{ member_value } ` error."""' ,
174+ '' ,
175+ ]
176+ )
177+
178+ lines .extend (
179+ [
180+ '' ,
181+ 'API_ERROR_CLASS_BY_TYPE: dict[str, type[ApifyApiError]] = {' ,
182+ * (f" '{ member_value } ': { class_name } ," for _ , member_value , class_name in classes ),
183+ '}' ,
184+ '' ,
185+ '' ,
186+ '__all__ = [' ,
187+ * (f" '{ name } '," for name in sorted (['API_ERROR_CLASS_BY_TYPE' , * [c for _ , _ , c in classes ]])),
188+ ']' ,
189+ '' ,
190+ ]
191+ )
192+ return '\n ' .join (lines )
193+
194+
195+ def write_generated_errors_module (content : str ) -> bool :
196+ """Derive and write `_generated_errors.py`. Returns True if the file changed."""
197+ members = extract_error_type_members (content )
198+ if not members :
199+ return False
200+ classes = derive_exception_class_names (members )
201+ rendered = render_generated_errors_module (classes )
202+ previous = GENERATED_ERRORS_PATH .read_text () if GENERATED_ERRORS_PATH .exists () else ''
203+ if rendered != previous :
204+ GENERATED_ERRORS_PATH .write_text (rendered )
205+ return True
206+ return False
207+
208+
79209def main () -> None :
80210 content = MODELS_PATH .read_text ()
81211 fixed = fix_discriminators (content )
@@ -88,6 +218,11 @@ def main() -> None:
88218 else :
89219 print ('No fixes needed' )
90220
221+ if write_generated_errors_module (fixed ):
222+ print (f'Regenerated error classes in { GENERATED_ERRORS_PATH } ' )
223+ else :
224+ print ('No error-class regeneration needed' )
225+
91226
92227if __name__ == '__main__' :
93228 main ()
0 commit comments