Skip to content

Commit 8e79d30

Browse files
committed
Update: modify 6 file(s)
1 parent 3c15b7b commit 8e79d30

6 files changed

Lines changed: 144 additions & 18 deletions

File tree

docs/docs/Guides/StructuredModel_Export.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,4 +405,4 @@ git commit -m "Add Product model schema v1"
405405

406406
- [StructuredModel Dynamic Creation](StructuredModel_Dynamic_Creation.md) - Import methods
407407
- [StructuredModel Advanced Functionality](StructuredModel_Advanced_Functionality.md) - Comparison features
408-
- [JSON Schema Extensions](../../index.md) - Full extension documentation in main README
408+
- [JSON Schema Extensions](../Getting-Started/README.md) - Full extension documentation

src/stickler/comparators/numeric.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from decimal import Decimal, InvalidOperation
5-
from typing import Any, Optional, Union
5+
from typing import Any, Dict, Optional, Union
66

77
from stickler.comparators.base import BaseComparator
88

@@ -55,6 +55,16 @@ def __init__(
5555
else:
5656
self.absolute_tolerance = absolute_tolerance
5757

58+
@property
59+
def config(self) -> Optional[Dict[str, Any]]:
60+
"""Return configuration parameters for serialization."""
61+
config = {}
62+
if self.relative_tolerance != 0.0:
63+
config["relative_tolerance"] = self.relative_tolerance
64+
if self.absolute_tolerance != 0.0:
65+
config["absolute_tolerance"] = self.absolute_tolerance
66+
return config or None
67+
5868
def compare(self, str1: Any, str2: Any) -> float:
5969
"""Compare two values numerically.
6070

src/stickler/structured_object_evaluator/models/json_schema_field_converter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,9 @@ def _extract_stickler_extensions(
216216
# Extract comparator
217217
if "x-aws-stickler-comparator" in property_schema:
218218
comparator_name = property_schema["x-aws-stickler-comparator"]
219+
comparator_config = property_schema.get("x-aws-stickler-comparator-config", {})
219220
try:
220-
extensions["comparator"] = self._create_comparator_from_name(comparator_name)
221+
extensions["comparator"] = create_comparator(comparator_name, comparator_config)
221222
except Exception as e:
222223
field_info = f" in field '{field_path}'" if field_path else ""
223224
raise ValueError(
@@ -534,6 +535,8 @@ def _build_comparison_extensions(
534535
Dictionary with comparison extensions in the specified format
535536
"""
536537
extensions = {}
538+
if format not in ("json_schema", "stickler_config"):
539+
raise ValueError(f"Unsupported format: {format!r}. Use 'json_schema' or 'stickler_config'.")
537540
prefix = "x-aws-stickler-" if format == "json_schema" else ""
538541

539542
# Export comparator class name and configuration

src/stickler/structured_object_evaluator/models/structured_model.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,7 +1194,9 @@ def to_json_schema(cls) -> Dict[str, Any]:
11941194
}
11951195

11961196
# Add match_threshold if available (check both attribute names for compatibility)
1197-
threshold = getattr(cls, "match_threshold", None) or getattr(cls, "_match_threshold", None)
1197+
threshold = getattr(cls, "match_threshold", None)
1198+
if threshold is None:
1199+
threshold = getattr(cls, "_match_threshold", None)
11981200
if threshold is not None:
11991201
schema["x-aws-stickler-match-threshold"] = threshold
12001202

@@ -1207,6 +1209,7 @@ def to_json_schema(cls) -> Dict[str, Any]:
12071209

12081210
# Validate field has type annotation
12091211
if field_type is None:
1212+
# Defensive: unreachable through normal Pydantic model construction
12101213
raise ValueError(f"Field '{field_name}' has no type annotation")
12111214

12121215
# Unwrap Optional before type checking
@@ -1219,6 +1222,7 @@ def to_json_schema(cls) -> Dict[str, Any]:
12191222
# Handle List[StructuredModel] or List[primitive]
12201223
args = get_args(field_type)
12211224
if not args:
1225+
# Defensive: unreachable through normal Pydantic model construction
12221226
raise ValueError(
12231227
f"Field '{field_name}' has unparameterized list type. "
12241228
f"Use List[str], List[int], etc."
@@ -1321,7 +1325,9 @@ def to_stickler_config(cls) -> Dict[str, Any]:
13211325
}
13221326

13231327
# Add match_threshold if available (check both attribute names for compatibility)
1324-
threshold = getattr(cls, "match_threshold", None) or getattr(cls, "_match_threshold", None)
1328+
threshold = getattr(cls, "match_threshold", None)
1329+
if threshold is None:
1330+
threshold = getattr(cls, "_match_threshold", None)
13251331
if threshold is not None:
13261332
config["match_threshold"] = threshold
13271333

@@ -1334,35 +1340,38 @@ def to_stickler_config(cls) -> Dict[str, Any]:
13341340

13351341
# Validate field has type annotation
13361342
if field_type is None:
1343+
# Defensive: unreachable through normal Pydantic model construction
13371344
raise ValueError(f"Field '{field_name}' has no type annotation")
13381345

13391346
# Unwrap Optional before type checking
13401347
field_type, _ = cls._unwrap_optional(field_type)
13411348

13421349
# Check if nested StructuredModel - use "structured_model" type
13431350
if cls._is_structured_model_type(field_type):
1344-
field_config = {
1345-
"type": "structured_model",
1346-
# Recursively export nested model's fields
1347-
"fields": field_type.to_stickler_config()["fields"]
1348-
}
1351+
nested_config = field_type.to_stickler_config()
1352+
field_config = {"type": "structured_model", "fields": nested_config["fields"]}
1353+
if nested_config.get("model_name"):
1354+
field_config["model_name"] = nested_config["model_name"]
1355+
if nested_config.get("match_threshold") is not None:
1356+
field_config["match_threshold"] = nested_config["match_threshold"]
13491357
elif get_origin(field_type) is list:
13501358
# Handle List[StructuredModel] or List[primitive]
13511359
args = get_args(field_type)
13521360
if not args:
1361+
# Defensive: unreachable through normal Pydantic model construction
13531362
raise ValueError(
13541363
f"Field '{field_name}' has unparameterized list type. "
13551364
f"Use List[str], List[int], etc."
13561365
)
13571366
element_type = args[0]
13581367

13591368
if cls._is_structured_model_type(element_type):
1360-
# List of StructuredModels - use "list_structured_model" type
1361-
field_config = {
1362-
"type": "list_structured_model",
1363-
# Recursively export element model's fields
1364-
"fields": element_type.to_stickler_config()["fields"]
1365-
}
1369+
nested_config = element_type.to_stickler_config()
1370+
field_config = {"type": "list_structured_model", "fields": nested_config["fields"]}
1371+
if nested_config.get("model_name"):
1372+
field_config["model_name"] = nested_config["model_name"]
1373+
if nested_config.get("match_threshold") is not None:
1374+
field_config["match_threshold"] = nested_config["match_threshold"]
13661375
else:
13671376
# Primitive list - use converter
13681377
field_config = converter.field_to_stickler_config(field_type, field_info)

tests/structured_object_evaluator/test_export_roundtrip.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
preserving their comparison behavior.
55
"""
66

7-
from typing import List
7+
from typing import List, Optional
88

99
from stickler.comparators.levenshtein import LevenshteinComparator
1010
from stickler.comparators.numeric import NumericComparator
@@ -218,3 +218,52 @@ def test_multiple_roundtrips():
218218
result2 = p2.compare_with(p2)
219219

220220
assert result_orig["overall_score"] == result1["overall_score"] == result2["overall_score"] == 1.0
221+
222+
223+
def test_optional_field_roundtrip():
224+
"""Test round-trip with Optional fields preserves comparison behavior."""
225+
226+
class OptionalModel(StructuredModel):
227+
name: str = ComparableField(threshold=0.8, default=...)
228+
note: Optional[str] = ComparableField(threshold=0.6, default=None)
229+
230+
# JSON Schema round-trip
231+
schema = OptionalModel.to_json_schema()
232+
Reconstructed = StructuredModel.from_json_schema(schema)
233+
234+
# Test with non-None values (Optional nature is not preserved in round-trip — known limitation)
235+
o1 = OptionalModel(name="Test", note="hello")
236+
r1 = Reconstructed(name="Test", note="hello")
237+
238+
result_orig = o1.compare_with(o1)
239+
result_recon = r1.compare_with(r1)
240+
241+
assert result_orig["overall_score"] == 1.0
242+
assert result_recon["overall_score"] == 1.0
243+
244+
245+
def test_numeric_comparator_tolerance_roundtrip():
246+
"""Test that NumericComparator tolerance is preserved after round-trip."""
247+
248+
class TolerantModel(StructuredModel):
249+
value: float = ComparableField(
250+
comparator=NumericComparator(absolute_tolerance=0.5),
251+
threshold=1.0,
252+
default=...
253+
)
254+
255+
# JSON Schema round-trip
256+
schema = TolerantModel.to_json_schema()
257+
Reconstructed = StructuredModel.from_json_schema(schema)
258+
r1 = Reconstructed(value=100.0)
259+
r2 = Reconstructed(value=100.3)
260+
result = r1.compare_with(r2)
261+
assert result["field_scores"]["value"] == 1.0
262+
263+
# Stickler config round-trip
264+
config = TolerantModel.to_stickler_config()
265+
Reconstructed2 = StructuredModel.model_from_json(config)
266+
r3 = Reconstructed2(value=100.0)
267+
r4 = Reconstructed2(value=100.3)
268+
result2 = r3.compare_with(r4)
269+
assert result2["field_scores"]["value"] == 1.0

tests/structured_object_evaluator/test_model_export.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
that export StructuredModel configurations for serialization.
55
"""
66

7-
from typing import List
7+
from typing import List, Optional
88

99
from stickler.comparators.levenshtein import LevenshteinComparator
1010
from stickler.comparators.numeric import NumericComparator
@@ -193,3 +193,58 @@ class DetailedModel(StructuredModel):
193193
assert field_config["weight"] == 1.5
194194
assert field_config["clip_under_threshold"] is False
195195
assert field_config["aggregate"] is True
196+
197+
198+
class OptionalFieldModel(StructuredModel):
199+
"""Model with Optional fields for testing export."""
200+
required_name: str = ComparableField(threshold=0.8, default=...)
201+
optional_note: Optional[str] = ComparableField(threshold=0.6, default=None)
202+
optional_product: Optional[SimpleProduct] = ComparableField(default=None)
203+
optional_products: Optional[List[SimpleProduct]] = ComparableField(default=None)
204+
205+
206+
def test_to_json_schema_optional_fields():
207+
"""Test that Optional fields export correctly and are not in required list."""
208+
schema = OptionalFieldModel.to_json_schema()
209+
210+
# Required field is in required list
211+
assert "required_name" in schema["required"]
212+
213+
# Optional fields are NOT in required list
214+
assert "optional_note" not in schema["required"]
215+
assert "optional_product" not in schema["required"]
216+
assert "optional_products" not in schema["required"]
217+
218+
# Optional[str] unwraps to string type
219+
assert schema["properties"]["optional_note"]["type"] == "string"
220+
221+
# Optional[StructuredModel] unwraps to nested object schema
222+
product_prop = schema["properties"]["optional_product"]
223+
assert product_prop["type"] == "object"
224+
assert "name" in product_prop["properties"]
225+
assert "price" in product_prop["properties"]
226+
227+
# Optional[List[StructuredModel]] unwraps to array with nested items
228+
products_prop = schema["properties"]["optional_products"]
229+
assert products_prop["type"] == "array"
230+
assert products_prop["items"]["type"] == "object"
231+
assert "name" in products_prop["items"]["properties"]
232+
233+
234+
def test_to_stickler_config_optional_fields():
235+
"""Test that Optional fields export correctly in Stickler config format."""
236+
config = OptionalFieldModel.to_stickler_config()
237+
238+
# Optional[str] unwraps to str type
239+
assert config["fields"]["optional_note"]["type"] == "str"
240+
241+
# Optional[StructuredModel] unwraps to structured_model type
242+
product_field = config["fields"]["optional_product"]
243+
assert product_field["type"] == "structured_model"
244+
assert "name" in product_field["fields"]
245+
assert "price" in product_field["fields"]
246+
247+
# Optional[List[StructuredModel]] unwraps to list_structured_model type
248+
products_field = config["fields"]["optional_products"]
249+
assert products_field["type"] == "list_structured_model"
250+
assert "name" in products_field["fields"]

0 commit comments

Comments
 (0)