Skip to content

Commit f10b306

Browse files
committed
address feedback
1 parent 8c286c0 commit f10b306

2 files changed

Lines changed: 36 additions & 35 deletions

File tree

scripts/postprocess_generated_models.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,17 @@
3636
'pricingModel': 'pricing_model',
3737
}
3838

39-
# Seed models for the TypedDict pruning. Every TypedDict in `_typeddicts.py` that is not
40-
# transitively reachable from this set is removed. Keep in sync with the `dict | <Model>` unions
41-
# on resource-client method signatures.
42-
TYPEDDICT_SEEDS: frozenset[str] = frozenset(
39+
# TypedDicts accepted as inputs by resource-client methods. These are the roots of the reachability
40+
# walk over `_typeddicts_generated.py`: anything not reachable from here (directly or transitively)
41+
# is dropped so only the TypedDicts that are part of the public input surface — plus their nested
42+
# shapes — survive. Names are the raw datamodel-codegen outputs (no `Dict` suffix yet); the suffix
43+
# is added later by `rename_with_dict_suffix`. Update this set whenever a new `<Name>Dict | <Name>`
44+
# union is introduced on a resource-client method signature.
45+
RESOURCE_INPUT_TYPEDDICTS: frozenset[str] = frozenset(
4346
{
44-
'Request',
45-
'TaskInput',
46-
'WebhookCreate',
47+
'Request', # RequestQueueClient.update_request
48+
'TaskInput', # Actor/Task start/call/update default input
49+
'WebhookCreate', # Actor/Task start/call webhook list element
4750
}
4851
)
4952

@@ -191,17 +194,17 @@ def _collect_name_references(node: ast.AST, exclude: set[str]) -> set[str]:
191194
return refs - exclude
192195

193196

194-
def _compute_transitive_closure(deps: dict[str, set[str]], seeds: set[str]) -> set[str]:
195-
"""Return every symbol transitively reachable from any seed."""
196-
closure: set[str] = set()
197+
def _compute_reachable_symbols(deps: dict[str, set[str]], seeds: set[str]) -> set[str]:
198+
"""Return every symbol transitively reachable from any seed via `deps`."""
199+
reachable: set[str] = set()
197200
stack = [s for s in seeds if s in deps]
198201
while stack:
199202
name = stack.pop()
200-
if name in closure:
203+
if name in reachable:
201204
continue
202-
closure.add(name)
203-
stack.extend(ref for ref in deps[name] if ref in deps and ref not in closure)
204-
return closure
205+
reachable.add(name)
206+
stack.extend(ref for ref in deps[name] if ref in deps and ref not in reachable)
207+
return reachable
205208

206209

207210
def prune_typeddicts(content: str, seeds: frozenset[str]) -> tuple[str, set[str]]:
@@ -218,7 +221,7 @@ def prune_typeddicts(content: str, seeds: frozenset[str]) -> tuple[str, set[str]
218221
# Ignore builtins and imported names — we only care about cross-references within the file.
219222
deps[name] = _collect_name_references(node, exclude={name}) & symbol_names
220223

221-
kept = _compute_transitive_closure(deps, set(seeds))
224+
kept = _compute_reachable_symbols(deps, set(seeds))
222225

223226
missing_seeds = seeds - symbol_names
224227
if missing_seeds:
@@ -265,7 +268,7 @@ def postprocess_models(path: Path) -> bool:
265268
def postprocess_typeddicts(path: Path) -> bool:
266269
"""Apply `_typeddicts_generated.py`-specific fixes. Returns True if the file changed."""
267270
original = path.read_text()
268-
pruned, kept = prune_typeddicts(original, TYPEDDICT_SEEDS)
271+
pruned, kept = prune_typeddicts(original, RESOURCE_INPUT_TYPEDDICTS)
269272
renamed = rename_with_dict_suffix(pruned, kept)
270273
flattened = flatten_empty_typeddicts(renamed)
271274
final = add_docs_group_decorators(flattened, 'Typed dicts')

tests/unit/test_utils.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import io
24
from datetime import timedelta
35
from http import HTTPStatus
4-
from typing import TYPE_CHECKING, cast
6+
from typing import TYPE_CHECKING
57
from unittest.mock import Mock
68

79
import impit
@@ -57,24 +59,20 @@ def test_webhook_representation_list_to_base64() -> None:
5759

5860

5961
def test_webhook_representation_list_from_dicts() -> None:
60-
"""Test that from_webhooks accepts plain dicts with the minimal ad-hoc webhook shape."""
61-
# The runtime only needs the keys consumed by WebhookRepresentation (event_types, request_url,
62-
# optionally payload_template/headers_template). WebhookCreateDict requires more — cast to
63-
# silence the type checker while still exercising the minimal-dict runtime path.
64-
webhooks = cast(
65-
'list[WebhookCreateDict]',
66-
[
67-
{
68-
'event_types': ['ACTOR.RUN.CREATED'],
69-
'request_url': 'https://example.com/run-created',
70-
},
71-
{
72-
'event_types': ['ACTOR.RUN.SUCCEEDED'],
73-
'request_url': 'https://example.com/run-succeeded',
74-
'payload_template': '{"hello": "world"}',
75-
},
76-
],
77-
)
62+
"""Test that from_webhooks accepts plain dicts typed as WebhookCreateDict."""
63+
webhooks: list[WebhookCreateDict] = [
64+
{
65+
'event_types': ['ACTOR.RUN.CREATED'],
66+
'condition': {},
67+
'request_url': 'https://example.com/run-created',
68+
},
69+
{
70+
'event_types': ['ACTOR.RUN.SUCCEEDED'],
71+
'condition': {},
72+
'request_url': 'https://example.com/run-succeeded',
73+
'payload_template': '{"hello": "world"}',
74+
},
75+
]
7876
result = WebhookRepresentationList.from_webhooks(webhooks).to_base64()
7977

8078
assert result is not None

0 commit comments

Comments
 (0)