Skip to content

Commit 89bfb88

Browse files
vdusekclaude
andcommitted
feat: generate Exception subclass per API error type
Addresses #423. Each `ErrorType` enum member from the OpenAPI spec now has a matching `ApifyApiError` subclass (e.g. `RecordNotFoundError`) generated into `_generated_errors.py`. `ApifyApiError.__new__` dispatches to the right subclass based on the response's `error.type`, so `except ApifyApiError` keeps working while `except RecordNotFoundError` becomes possible. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6c3b84d commit 89bfb88

6 files changed

Lines changed: 3091 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ indent-style = "space"
158158
"ERA001", # Commented-out code
159159
"TC003", # Move standard library import into a type-checking block
160160
]
161+
"src/apify_client/_generated_errors.py" = [
162+
"E501", # Line too long (long error-type keys pushing dict entries over the limit)
163+
]
161164

162165
[tool.ruff.lint.flake8-quotes]
163166
docstring-quotes = "double"

scripts/postprocess_generated_models.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,20 @@ class alongside the canonical `ErrorType(StrEnum)`. This script removes the dupl
1010
rewires references to use `ErrorType`.
1111
- Missing @docs_group decorator: Adds `@docs_group('Models')` to all model classes for API
1212
reference documentation grouping, along with the required import.
13+
14+
Also generates `_generated_errors.py` — one `ApifyApiError` subclass per `ErrorType` enum member
15+
plus a dispatch map used by `ApifyApiError.__new__` to return the specific subclass.
1316
"""
1417

1518
from __future__ import annotations
1619

20+
import ast
21+
import builtins
1722
import re
1823
from pathlib import Path
1924

2025
MODELS_PATH = Path(__file__).resolve().parent.parent / 'src' / 'apify_client' / '_models.py'
26+
GENERATED_ERRORS_PATH = Path(__file__).resolve().parent.parent / 'src' / 'apify_client' / '_generated_errors.py'
2127
DOCS_GROUP_DECORATOR = "@docs_group('Models')"
2228

2329
# Map of camelCase discriminator values to their snake_case equivalents.
@@ -76,6 +82,130 @@ def add_docs_group_decorators(content: str) -> str:
7682
return '\n'.join(result)
7783

7884

85+
def extract_error_type_members(content: str) -> list[tuple[str, str]]:
86+
"""Parse `_models.py` and return `(member_name, member_value)` tuples for the `ErrorType` enum.
87+
88+
Uses AST parsing for robustness against formatting differences. Returns an empty list if the
89+
`ErrorType` class is not found.
90+
"""
91+
tree = ast.parse(content)
92+
for node in ast.walk(tree):
93+
if isinstance(node, ast.ClassDef) and node.name == 'ErrorType':
94+
return [
95+
(stmt.targets[0].id, stmt.value.value)
96+
for stmt in node.body
97+
if (
98+
isinstance(stmt, ast.Assign)
99+
and len(stmt.targets) == 1
100+
and isinstance(stmt.targets[0], ast.Name)
101+
and isinstance(stmt.value, ast.Constant)
102+
and isinstance(stmt.value.value, str)
103+
)
104+
]
105+
return []
106+
107+
108+
def _pascal_case(name: str) -> str:
109+
"""Convert `SCREAMING_SNAKE_CASE` to `PascalCase`, preserving all-caps parts that contain digits.
110+
111+
Parts like `3D` or `X402` are left as-is so the result reads naturally (e.g.
112+
`FIELD_3D_SECURE` → `Field3DSecure` rather than `Field3dSecure`).
113+
"""
114+
return ''.join(part if any(c.isdigit() for c in part) else part.capitalize() for part in name.split('_'))
115+
116+
117+
def derive_exception_class_names(members: list[tuple[str, str]]) -> list[tuple[str, str, str]]:
118+
"""Derive unique Exception class names for each `ErrorType` enum member.
119+
120+
Strategy: strip a trailing `_ERROR` from the enum name and PascalCase the result, then append
121+
`Error`. If that collides with a previously derived name, always append `Error` to the full
122+
enum name — so `SCHEMA_VALIDATION` → `SchemaValidationError` (first wins) and
123+
`SCHEMA_VALIDATION_ERROR` falls back to `SchemaValidationErrorError`.
124+
125+
Returns a list of `(member_name, member_value, class_name)` tuples.
126+
"""
127+
taken: set[str] = set()
128+
builtin_names = set(dir(builtins))
129+
result: list[tuple[str, str, str]] = []
130+
for member_name, member_value in members:
131+
stripped = member_name.removesuffix('_ERROR')
132+
candidate = _pascal_case(stripped) + 'Error'
133+
if candidate in taken:
134+
candidate = _pascal_case(member_name) + 'Error'
135+
# Avoid shadowing builtins like `NotImplementedError` or `TimeoutError`.
136+
if candidate in builtin_names:
137+
candidate = 'Api' + candidate
138+
if candidate in taken:
139+
raise RuntimeError(
140+
f'Cannot derive a unique Exception class name for ErrorType.{member_name} '
141+
f'(value={member_value!r}); collides with an existing class. '
142+
'Extend derive_exception_class_names to handle this case.'
143+
)
144+
taken.add(candidate)
145+
result.append((member_name, member_value, candidate))
146+
return result
147+
148+
149+
def render_generated_errors_module(classes: list[tuple[str, str, str]]) -> str:
150+
"""Render the full `_generated_errors.py` source from the derived class list."""
151+
lines: list[str] = [
152+
'# generated by scripts/postprocess_generated_models.py -- do not edit manually',
153+
'"""Auto-generated Exception subclasses, one per `ErrorType` enum member.',
154+
'',
155+
'Each subclass inherits from `ApifyApiError` so existing `except ApifyApiError` handlers',
156+
'keep working. `ApifyApiError.__new__` uses `API_ERROR_CLASS_BY_TYPE` to dispatch to the',
157+
'specific subclass based on the `type` field of the API error response.',
158+
'"""',
159+
'',
160+
'from __future__ import annotations',
161+
'',
162+
'from apify_client._docs import docs_group',
163+
'from apify_client.errors import ApifyApiError',
164+
'',
165+
]
166+
167+
for _member_name, member_value, class_name in classes:
168+
lines.extend(
169+
[
170+
'',
171+
"@docs_group('Errors')",
172+
f'class {class_name}(ApifyApiError):',
173+
f' """Raised when the Apify API returns a `{member_value}` error."""',
174+
'',
175+
]
176+
)
177+
178+
lines.extend(
179+
[
180+
'',
181+
'API_ERROR_CLASS_BY_TYPE: dict[str, type[ApifyApiError]] = {',
182+
*(f" '{member_value}': {class_name}," for _, member_value, class_name in classes),
183+
'}',
184+
'',
185+
'',
186+
'__all__ = [',
187+
*(f" '{name}'," for name in sorted(['API_ERROR_CLASS_BY_TYPE', *[c for _, _, c in classes]])),
188+
']',
189+
'',
190+
]
191+
)
192+
return '\n'.join(lines)
193+
194+
195+
def write_generated_errors_module(content: str) -> bool:
196+
"""Derive and write `_generated_errors.py`. Returns True if the file changed."""
197+
members = extract_error_type_members(content)
198+
if not members:
199+
return False
200+
classes = derive_exception_class_names(members)
201+
rendered = render_generated_errors_module(classes)
202+
previous = GENERATED_ERRORS_PATH.read_text() if GENERATED_ERRORS_PATH.exists() else ''
203+
if rendered != previous:
204+
GENERATED_ERRORS_PATH.write_text(rendered)
205+
return True
206+
return False
207+
208+
79209
def main() -> None:
80210
content = MODELS_PATH.read_text()
81211
fixed = fix_discriminators(content)
@@ -88,6 +218,11 @@ def main() -> None:
88218
else:
89219
print('No fixes needed')
90220

221+
if write_generated_errors_module(fixed):
222+
print(f'Regenerated error classes in {GENERATED_ERRORS_PATH}')
223+
else:
224+
print('No error-class regeneration needed')
225+
91226

92227
if __name__ == '__main__':
93228
main()

0 commit comments

Comments
 (0)