Skip to content

Commit 0d1a0c5

Browse files
Fix data models with compound fields by moving them to discriminated unions and forcing model rebuilds to resolve a forward reference to AnyComponent
1 parent 9d47ae7 commit 0d1a0c5

9 files changed

Lines changed: 268 additions & 59 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "oshconnect"
3-
version = "0.5.1a9"
3+
version = "0.5.1a11"
44
description = "Library for interfacing with OSH, helping guide visualization efforts, and providing a place to store configurations. Implements OGC CS API Part 3 (Pub/Sub) MQTT topic conventions including :data topics and resource event topics."
55
readme = "README.md"
66
authors = [

src/oshconnect/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@
3333
QuantityRangeSchema,
3434
TimeRangeSchema,
3535
)
36-
from .schema_datamodels import SWEDatastreamRecordSchema, OMJSONDatastreamRecordSchema, JSONCommandSchema
36+
from .schema_datamodels import (
37+
SWEDatastreamRecordSchema,
38+
OMJSONDatastreamRecordSchema,
39+
SWEJSONCommandSchema,
40+
JSONCommandSchema,
41+
AnyDatastreamRecordSchema,
42+
AnyCommandSchema,
43+
)
3744

3845
# Event system
3946
from .events import EventHandler, IEventListener, CallbackListener, DefaultEventTypes, AtomicEventTypes, Event, EventBuilder
@@ -77,7 +84,10 @@
7784
"TimeRangeSchema",
7885
"SWEDatastreamRecordSchema",
7986
"OMJSONDatastreamRecordSchema",
87+
"SWEJSONCommandSchema",
8088
"JSONCommandSchema",
89+
"AnyDatastreamRecordSchema",
90+
"AnyCommandSchema",
8191
# Event system
8292
"EventHandler",
8393
"IEventListener",

src/oshconnect/resource_datamodels.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import json
1010
from typing import List, TYPE_CHECKING
1111

12-
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
12+
from pydantic import BaseModel, ConfigDict, Field, model_validator
1313
from shapely import Point
1414

1515
from .api_utils import Link
1616
from .geometry import Geometry
17-
from .schema_datamodels import DatastreamRecordSchema, CommandSchema
17+
from .schema_datamodels import AnyCommandSchema, AnyDatastreamRecordSchema
1818
from .timemanagement import TimeInstant, TimePeriod
1919

2020
if TYPE_CHECKING:
@@ -227,7 +227,7 @@ class DatastreamResource(BaseModel):
227227
observed_properties: List[dict] = Field(default_factory=list, alias="observedProperties")
228228
system_id: str = Field(None, alias="system@id")
229229
links: List[Link] = Field(None)
230-
record_schema: SerializeAsAny[DatastreamRecordSchema] = Field(None, alias="schema")
230+
record_schema: AnyDatastreamRecordSchema = Field(None, alias="schema")
231231

232232
@classmethod
233233
@model_validator(mode="before")
@@ -371,7 +371,7 @@ class ControlStreamResource(BaseModel):
371371
execution_time: TimePeriod = Field(None, alias="executionTime")
372372
live: bool = Field(None)
373373
asynchronous: bool = Field(True, alias="async")
374-
command_schema: SerializeAsAny[CommandSchema] = Field(None, alias="schema")
374+
command_schema: AnyCommandSchema = Field(None, alias="schema")
375375
links: List[Link] = Field(None)
376376

377377
def to_csapi_dict(self) -> dict:

src/oshconnect/schema_datamodels.py

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
from __future__ import annotations
88

99
from datetime import datetime
10-
from typing import Union, List, Literal
10+
from typing import Annotated, Union, List, Literal
1111

12-
from pydantic import BaseModel, Field, SerializeAsAny, field_validator, model_validator, HttpUrl, ConfigDict
12+
from pydantic import BaseModel, Field, model_validator, HttpUrl, ConfigDict
1313

1414
from .api_utils import Link, URI
15-
from .csapi4py.constants import ObservationFormat
16-
from .encoding import Encoding
15+
from .encoding import JSONEncoding
1716
from .geometry import Geometry
1817
from .swe_components import AnyComponent, check_named
1918
from .timemanagement import TimeInstant
@@ -76,8 +75,16 @@ class SWEJSONCommandSchema(CommandSchema):
7675
"""
7776
model_config = ConfigDict(populate_by_name=True)
7877

79-
command_format: str = Field("application/swe+json", alias='commandFormat')
80-
encoding: SerializeAsAny[Encoding] = Field(...)
78+
# Literal pin powers the discriminated `AnyCommandSchema` union below
79+
# and removes the need for a runtime field_validator.
80+
command_format: Literal["application/swe+json"] = Field(
81+
"application/swe+json", alias='commandFormat')
82+
# Concrete subclass instead of `SerializeAsAny[Encoding]` — `JSONEncoding`
83+
# is the only Encoding type used in practice, and a concrete type
84+
# serializes deterministically without `SerializeAsAny`. If/when more
85+
# encoding types arrive, migrate this to a discriminated Union on
86+
# `Encoding.type`.
87+
encoding: JSONEncoding = Field(...)
8188
record_schema: AnyComponent = Field(..., alias='recordSchema')
8289

8390
@model_validator(mode="after")
@@ -140,17 +147,17 @@ class DatastreamRecordSchema(BaseModel):
140147
# docs/osh_spec_deviations.md (swe-json-missing-encoding).
141148
class SWEDatastreamRecordSchema(DatastreamRecordSchema):
142149
model_config = ConfigDict(populate_by_name=True)
143-
encoding: SerializeAsAny[Encoding] = Field(None)
150+
# Multi-Literal acts as the discriminator value(s) for AnyDatastreamRecordSchema
151+
# below. Replaces the previous runtime field_validator.
152+
obs_format: Literal[
153+
"application/swe+json",
154+
"application/swe+csv",
155+
"application/swe+text",
156+
"application/swe+binary",
157+
] = Field(..., alias='obsFormat')
158+
encoding: JSONEncoding = Field(None)
144159
record_schema: AnyComponent = Field(..., alias='recordSchema')
145160

146-
@field_validator('obs_format')
147-
@classmethod
148-
def check_check_obs_format(cls, v):
149-
if v not in [ObservationFormat.SWE_JSON.value, ObservationFormat.SWE_CSV.value,
150-
ObservationFormat.SWE_TEXT.value, ObservationFormat.SWE_BINARY.value]:
151-
raise ValueError('obsFormat must be on of the SWE formats')
152-
return v
153-
154161
@model_validator(mode="after")
155162
def _root_record_schema_requires_name(self):
156163
check_named(self.record_schema, "SWEDatastreamRecordSchema.recordSchema")
@@ -178,20 +185,15 @@ class OMJSONDatastreamRecordSchema(DatastreamRecordSchema):
178185
"""
179186
model_config = ConfigDict(populate_by_name=True)
180187

181-
obs_format: str = Field(ObservationFormat.JSON.value, alias='obsFormat')
188+
# Multi-Literal — both wire forms are spec-equivalent for OM+JSON.
189+
obs_format: Literal[
190+
"application/om+json",
191+
"application/json",
192+
] = Field("application/om+json", alias='obsFormat')
182193
result_schema: AnyComponent = Field(None, alias='resultSchema')
183194
parameters_schema: AnyComponent = Field(None, alias='parametersSchema')
184195
result_link: dict = Field(None, alias='resultLink')
185196

186-
@field_validator('obs_format')
187-
@classmethod
188-
def _check_obs_format(cls, v):
189-
if v not in (ObservationFormat.JSON.value, "application/json"):
190-
raise ValueError(
191-
f"obsFormat must be 'application/json' or '{ObservationFormat.JSON.value}'"
192-
)
193-
return v
194-
195197
@model_validator(mode="after")
196198
def _root_schemas_require_name(self):
197199
if self.result_schema is not None:
@@ -339,3 +341,30 @@ class SystemHistoryProperties(BaseModel):
339341
valid_time: list = Field(None)
340342
parent_system_link: str = Field(None, serialization_alias='parentSystem@link')
341343
procedure_link: str = Field(None, serialization_alias='procedure@link')
344+
345+
346+
# Discriminated unions replace the earlier `SerializeAsAny[<base>]` pattern
347+
# on resource models. Pydantic dispatches by the literal value of the
348+
# discriminator field — `obsFormat` / `commandFormat` — so validate and
349+
# dump round-trip without polymorphism quirks.
350+
AnyDatastreamRecordSchema = Annotated[
351+
Union[SWEDatastreamRecordSchema, OMJSONDatastreamRecordSchema],
352+
Field(discriminator='obs_format'),
353+
]
354+
"""Public alias for `DatastreamResource.record_schema`. Discriminator: `obs_format`."""
355+
356+
AnyCommandSchema = Annotated[
357+
Union[SWEJSONCommandSchema, JSONCommandSchema],
358+
Field(discriminator='command_format'),
359+
]
360+
"""Public alias for `ControlStreamResource.command_schema`. Discriminator: `command_format`."""
361+
362+
363+
# Defense-in-depth: rebuild every container model that forward-references
364+
# `AnyComponent`. See the matching block in swe_components.py for the
365+
# `MockValSer` rationale — same fault recurs here because each schema
366+
# class threads `AnyComponent` through its body.
367+
SWEJSONCommandSchema.model_rebuild(force=True)
368+
JSONCommandSchema.model_rebuild(force=True)
369+
SWEDatastreamRecordSchema.model_rebuild(force=True)
370+
OMJSONDatastreamRecordSchema.model_rebuild(force=True)

src/oshconnect/streamableresource.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,28 +1173,14 @@ def add_insert_datastream(self, datastream_schema: DatastreamResource):
11731173
component requires a name.
11741174
:return:
11751175
"""
1176-
print(f'Adding datastream: {datastream_schema.model_dump_json(exclude_none=True, by_alias=True)}')
1177-
# Make the request to add the datastream
1178-
# if successful, add the datastream to the system
1179-
# datastream_record_schema = SWEDatastreamRecordSchema(record_schema=datastream_schema,
1180-
# obs_format='application/swe+json', encoding=JSONEncoding())
1181-
# datastream_resource = DatastreamResource(ds_id="default", name=datastream_schema.name,
1182-
# output_name=datastream_schema.name,
1183-
# record_schema=datastream_record_schema,
1184-
# valid_time=TimePeriod(start=TimeInstant.now_as_time_instant(),
1185-
# end=TimeInstant(utc_time=TimeUtils.to_utc_time(
1186-
# "2026-12-31T00:00:00Z"))))
1187-
11881176
api = self._parent_node.get_api_helper()
1189-
# print(f'Attempting to create datastream: {datastream_resource.model_dump(by_alias=True, exclude_none=True)}')
11901177
res = api.create_resource(APIResourceTypes.DATASTREAM,
11911178
datastream_schema.model_dump_json(by_alias=True, exclude_none=True),
11921179
req_headers={'Content-Type': ContentTypes.JSON.value},
11931180
parent_res_id=self._resource_id)
11941181

11951182
if res.ok:
11961183
datastream_id = res.headers['Location'].split('/')[-1]
1197-
print(f'Resource Location: {datastream_id}')
11981184
datastream_schema.ds_id = datastream_id
11991185
else:
12001186
raise Exception(
@@ -1714,33 +1700,48 @@ def get_status_deque_outbound(self) -> deque:
17141700
return self._outbound_status_deque
17151701

17161702
def publish_command(self, payload):
1717-
"""Publish ``payload`` to the command MQTT topic. Convenience wrapper for ``publish(payload, 'command')``."""
1703+
"""Publish ``payload`` to the command MQTT topic. Convenience wrapper
1704+
for ``publish(payload, APIResourceTypes.COMMAND.value)``."""
17181705
self.publish(payload, topic=APIResourceTypes.COMMAND.value)
17191706

17201707
def publish_status(self, payload):
1721-
"""Publish ``payload`` to the status MQTT topic. Convenience wrapper for ``publish(payload, 'status')``."""
1708+
"""Publish ``payload`` to the status MQTT topic. Convenience wrapper
1709+
for ``publish(payload, APIResourceTypes.STATUS.value)``."""
17221710
self.publish(payload, topic=APIResourceTypes.STATUS.value)
17231711

1724-
def publish(self, payload, topic: str = 'command'):
1712+
def publish(self, payload, topic: str = APIResourceTypes.COMMAND.value):
17251713
"""
17261714
Publishes data to the MQTT topic associated with this control stream resource.
1727-
:param payload: Data to be published, subclass should determine specifically allowed types
1728-
:param topic: Specific implementation determines the topic from the provided string
1715+
1716+
:param payload: Data to be published; subclasses determine which types are allowed.
1717+
:param topic: One of ``APIResourceTypes.COMMAND.value`` (``"Command"``,
1718+
the default) or ``APIResourceTypes.STATUS.value`` (``"Status"``).
1719+
Pass the enum value rather than a lowercase shorthand — the
1720+
comparison is case-sensitive against the canonical CS API
1721+
resource-type strings.
17291722
"""
17301723

17311724
if topic == APIResourceTypes.COMMAND.value:
17321725
self._publish_mqtt(self._topic, payload)
17331726
elif topic == APIResourceTypes.STATUS.value:
17341727
self._publish_mqtt(self._status_topic, payload)
17351728
else:
1736-
raise ValueError(f"Unsupported topic type {topic} for ControlStream publish().")
1729+
raise ValueError(
1730+
f"Unsupported topic {topic!r} for ControlStream publish(); "
1731+
f"expected {APIResourceTypes.COMMAND.value!r} or "
1732+
f"{APIResourceTypes.STATUS.value!r}."
1733+
)
17371734

17381735
def subscribe(self, topic=None, callback=None, qos=0):
17391736
"""
17401737
Subscribes to the MQTT topic associated with this control stream resource.
1741-
:param topic: Specific implementation determines the topic from the provided string
1742-
:param callback: Optional callback function to handle incoming messages, if None the default handler is used
1743-
:param qos: Quality of Service level for the subscription, default is 0
1738+
1739+
:param topic: ``None`` (defaults to the command topic),
1740+
``APIResourceTypes.COMMAND.value`` (``"Command"``), or
1741+
``APIResourceTypes.STATUS.value`` (``"Status"``). Comparison is
1742+
case-sensitive against the canonical CS API resource-type strings.
1743+
:param callback: Optional callback function to handle incoming messages; if None, the default handler is used.
1744+
:param qos: Quality of Service level for the subscription, default is 0.
17441745
"""
17451746

17461747
t = None
@@ -1750,7 +1751,11 @@ def subscribe(self, topic=None, callback=None, qos=0):
17501751
elif topic == APIResourceTypes.STATUS.value:
17511752
t = self._status_topic
17521753
else:
1753-
raise ValueError(f"Invalid topic provided {topic}, must be None or one of 'command' or 'status'.")
1754+
raise ValueError(
1755+
f"Invalid topic {topic!r}; must be None, "
1756+
f"{APIResourceTypes.COMMAND.value!r}, or "
1757+
f"{APIResourceTypes.STATUS.value!r}."
1758+
)
17541759

17551760
if callback is None:
17561761
self._mqtt_client.subscribe(t, qos=qos, msg_callback=self._mqtt_sub_callback)

src/oshconnect/swe_components.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from numbers import Real
1212
from typing import Union, Any, Literal, Annotated
1313

14-
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator, SerializeAsAny
14+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
1515

1616
from .csapi4py.constants import GeometryTypes
1717
from .api_utils import UCUMCode, URI
@@ -82,8 +82,7 @@ class VectorSchema(AnyComponentSchema):
8282
definition: str = Field(...)
8383
reference_frame: str = Field(..., alias='referenceFrame')
8484
local_frame: str = Field(None, alias='localFrame')
85-
# TODO: VERIFY might need to be moved further down when these are defined
86-
coordinates: SerializeAsAny[Union[list[CountSchema], list[QuantitySchema], list[TimeSchema]]] = Field(...)
85+
coordinates: Union[list[CountSchema], list[QuantitySchema], list[TimeSchema]] = Field(...)
8786

8887
@model_validator(mode="after")
8988
def _coordinates_require_name(self):
@@ -273,3 +272,17 @@ class CategoryRangeSchema(AnySimpleComponentSchema):
273272
],
274273
Field(discriminator="type"),
275274
]
275+
276+
277+
# Rebuild every container model that forward-references AnyComponent.
278+
# Without this, pydantic leaves a `MockValSer` placeholder on the
279+
# serializer side — `model_validate` upgrades the validator, but
280+
# `model_dump`/`model_dump_json` raise
281+
# `TypeError: 'MockValSer' object is not an instance of 'SchemaSerializer'`.
282+
# Plain `model_rebuild()` is a no-op (the class reports `model_complete`),
283+
# so `force=True` is required.
284+
DataRecordSchema.model_rebuild(force=True)
285+
VectorSchema.model_rebuild(force=True)
286+
DataArraySchema.model_rebuild(force=True)
287+
MatrixSchema.model_rebuild(force=True)
288+
DataChoiceSchema.model_rebuild(force=True)

0 commit comments

Comments
 (0)