Skip to content

Commit f9d8931

Browse files
committed
Type Stage 1 status snapshots
1 parent 04bc94b commit f9d8931

5 files changed

Lines changed: 305 additions & 65 deletions

File tree

policyengine_us_data/build_datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
from .status import Stage1ErrorRecord, Stage1StatusEvent
4444
from .status_store import (
4545
Stage1StatusRecorder,
46+
Stage1StatusReadError,
4647
Stage1StatusSnapshot,
48+
Stage1StoredStatusEvent,
4749
empty_stage_1_status_snapshot,
4850
read_stage_1_status_snapshot,
4951
)
@@ -69,9 +71,11 @@
6971
"Stage1Coordinator",
7072
"Stage1ErrorRecord",
7173
"Stage1StatusRecorder",
74+
"Stage1StatusReadError",
7275
"Stage1StatusEvent",
7376
"Stage1StatusSink",
7477
"Stage1StatusSnapshot",
78+
"Stage1StoredStatusEvent",
7579
"Stage1SubstepRunner",
7680
"SubprocessLogCapture",
7781
"TargetDatabaseSchemaSummaryWriter",

policyengine_us_data/build_datasets/results.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Mapping
5+
from collections.abc import Mapping, Sequence
66
from dataclasses import dataclass, field
7-
from typing import Any, Literal
7+
from typing import Any, Literal, cast
88

99
from .status import Stage1ErrorRecord, Stage1SubstepStatus
1010

@@ -27,6 +27,24 @@ class DatasetCommandResult:
2727
error: Stage1ErrorRecord | None = None
2828
metadata: Mapping[str, Any] = field(default_factory=dict)
2929

30+
@classmethod
31+
def from_dict(cls, data: Mapping[str, Any]) -> "DatasetCommandResult":
32+
"""Build a command result from a JSON-compatible payload."""
33+
34+
error = _error_record_from_payload(data.get("error"))
35+
return cls(
36+
command_name=str(data["command_name"]),
37+
argv=_string_tuple(data.get("argv", ())),
38+
status=_command_execution_status(data["status"]),
39+
returncode=_optional_int(data.get("returncode")),
40+
started_at=str(data["started_at"]),
41+
completed_at=str(data["completed_at"]),
42+
duration_s=float(data["duration_s"]),
43+
combined_output_tail=_string_tuple(data.get("combined_output_tail", ())),
44+
error=error,
45+
metadata=_metadata_mapping(data.get("metadata", {})),
46+
)
47+
3048
def to_dict(self) -> dict[str, Any]:
3149
"""Return a JSON-compatible command result payload."""
3250

@@ -60,6 +78,32 @@ class DatasetSubstepResult:
6078
error: Stage1ErrorRecord | None = None
6179
metadata: Mapping[str, Any] = field(default_factory=dict)
6280

81+
@classmethod
82+
def from_dict(cls, data: Mapping[str, Any]) -> "DatasetSubstepResult":
83+
"""Build a substep result from a JSON-compatible payload."""
84+
85+
command_results = data.get("command_results", ())
86+
if not isinstance(command_results, Sequence) or isinstance(
87+
command_results, str
88+
):
89+
raise TypeError("command_results must be a sequence")
90+
return cls(
91+
substep_id=str(data["substep_id"]),
92+
title=str(data["title"]),
93+
status=_stage_1_substep_status(data["status"]),
94+
started_at=_optional_str(data.get("started_at")),
95+
completed_at=str(data["completed_at"]),
96+
duration_s=_optional_float(data.get("duration_s")),
97+
command_names=_string_tuple(data.get("command_names", ())),
98+
command_results=tuple(
99+
DatasetCommandResult.from_dict(_mapping_payload(result))
100+
for result in command_results
101+
),
102+
artifact_paths=_string_tuple(data.get("artifact_paths", ())),
103+
error=_error_record_from_payload(data.get("error")),
104+
metadata=_metadata_mapping(data.get("metadata", {})),
105+
)
106+
63107
def to_dict(self) -> dict[str, Any]:
64108
"""Return a JSON-compatible substep result payload."""
65109

@@ -78,6 +122,54 @@ def to_dict(self) -> dict[str, Any]:
78122
}
79123

80124

125+
def _command_execution_status(value: Any) -> CommandExecutionStatus:
126+
if value in ("completed", "failed"):
127+
return cast(CommandExecutionStatus, value)
128+
raise ValueError(f"Invalid command execution status: {value!r}")
129+
130+
131+
def _stage_1_substep_status(value: Any) -> Stage1SubstepStatus:
132+
if value in ("started", "completed", "skipped", "failed"):
133+
return cast(Stage1SubstepStatus, value)
134+
raise ValueError(f"Invalid Stage 1 substep status: {value!r}")
135+
136+
137+
def _string_tuple(value: Any) -> tuple[str, ...]:
138+
if not isinstance(value, Sequence) or isinstance(value, str):
139+
raise TypeError("Expected a sequence")
140+
return tuple(str(item) for item in value)
141+
142+
143+
def _mapping_payload(value: Any) -> Mapping[str, Any]:
144+
if not isinstance(value, Mapping):
145+
raise TypeError("Expected a mapping")
146+
return value
147+
148+
149+
def _metadata_mapping(value: Any) -> Mapping[str, Any]:
150+
if not isinstance(value, Mapping):
151+
raise TypeError("metadata must be a mapping")
152+
return dict(value)
153+
154+
155+
def _error_record_from_payload(value: Any) -> Stage1ErrorRecord | None:
156+
if value is None:
157+
return None
158+
return Stage1ErrorRecord.from_dict(_mapping_payload(value))
159+
160+
161+
def _optional_str(value: Any) -> str | None:
162+
return None if value is None else str(value)
163+
164+
165+
def _optional_int(value: Any) -> int | None:
166+
return None if value is None else int(value)
167+
168+
169+
def _optional_float(value: Any) -> float | None:
170+
return None if value is None else float(value)
171+
172+
81173
__all__ = [
82174
"CommandExecutionStatus",
83175
"DatasetCommandResult",

policyengine_us_data/build_datasets/status.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Mapping
66
from dataclasses import dataclass, field, replace
77
from datetime import datetime, timezone
8-
from typing import TYPE_CHECKING, Any, Literal
8+
from typing import TYPE_CHECKING, Any, Literal, cast
99

1010
if TYPE_CHECKING:
1111
from modal_app.step_manifests.errors import PipelineErrorRecord
@@ -37,6 +37,19 @@ class Stage1StatusEvent:
3737
command_name: str | None = None
3838
metadata: Mapping[str, Any] = field(default_factory=dict)
3939

40+
@classmethod
41+
def from_dict(cls, data: Mapping[str, Any]) -> "Stage1StatusEvent":
42+
"""Build a status event from a JSON-compatible payload."""
43+
44+
return cls(
45+
substep_id=str(data["substep_id"]),
46+
status=_stage_1_substep_status(data["status"]),
47+
created_at=str(data["created_at"]),
48+
message=_optional_str(data.get("message")),
49+
command_name=_optional_str(data.get("command_name")),
50+
metadata=_metadata_mapping(data.get("metadata", {})),
51+
)
52+
4053
def to_dict(self) -> dict[str, Any]:
4154
"""Return a JSON-compatible status event payload."""
4255

@@ -67,6 +80,20 @@ class Stage1ErrorRecord:
6780
created_at: str = field(default_factory=utc_timestamp)
6881
metadata: Mapping[str, Any] = field(default_factory=dict)
6982

83+
@classmethod
84+
def from_dict(cls, data: Mapping[str, Any]) -> "Stage1ErrorRecord":
85+
"""Build an error record from a JSON-compatible payload."""
86+
87+
return cls(
88+
substep_id=_optional_str(data.get("substep_id")),
89+
command_name=_optional_str(data.get("command_name")),
90+
error_type=str(data["error_type"]),
91+
message=str(data["message"]),
92+
returncode=_optional_int(data.get("returncode")),
93+
created_at=str(data.get("created_at") or utc_timestamp()),
94+
metadata=_metadata_mapping(data.get("metadata", {})),
95+
)
96+
7097
@classmethod
7198
def from_exception(
7299
cls,
@@ -159,6 +186,26 @@ def _pipeline_traceback_text(error: Stage1ErrorRecord) -> str:
159186
return "\n\n".join(parts)
160187

161188

189+
def _stage_1_substep_status(value: Any) -> Stage1SubstepStatus:
190+
if value in ("started", "completed", "skipped", "failed"):
191+
return cast(Stage1SubstepStatus, value)
192+
raise ValueError(f"Invalid Stage 1 substep status: {value!r}")
193+
194+
195+
def _optional_str(value: Any) -> str | None:
196+
return None if value is None else str(value)
197+
198+
199+
def _optional_int(value: Any) -> int | None:
200+
return None if value is None else int(value)
201+
202+
203+
def _metadata_mapping(value: Any) -> Mapping[str, Any]:
204+
if not isinstance(value, Mapping):
205+
raise TypeError("metadata must be a mapping")
206+
return dict(value)
207+
208+
162209
__all__ = [
163210
"Stage1ErrorRecord",
164211
"Stage1StatusEvent",

0 commit comments

Comments
 (0)