Skip to content

Commit 070bac0

Browse files
committed
Update: modify 3 file(s)
1 parent a8f1a78 commit 070bac0

3 files changed

Lines changed: 49 additions & 32 deletions

File tree

src/stickler/structured_object_evaluator/models/json_schema_field_converter.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -266,26 +266,6 @@ def _extract_stickler_extensions(
266266

267267
return extensions
268268

269-
def _create_comparator_from_name(self, comparator_name: str):
270-
"""Create a comparator instance from its class name.
271-
272-
Args:
273-
comparator_name: Name of the comparator class
274-
275-
Returns:
276-
Comparator instance
277-
278-
Raises:
279-
ValueError: If comparator name is not registered
280-
"""
281-
# Use existing comparator_registry
282-
try:
283-
return create_comparator(comparator_name, {})
284-
except KeyError as e:
285-
# The KeyError message from the registry already contains the list of valid comparators
286-
# Re-raise as ValueError with the same information
287-
raise ValueError(str(e)) from e
288-
289269
def _resolve_ref(self, ref: str) -> Dict[str, Any]:
290270
"""Resolve a $ref reference within the schema.
291271
@@ -471,7 +451,7 @@ def field_to_property(self, field_type: Type, field_info: FieldInfo) -> Dict[str
471451

472452
# Extract metadata and build extensions using consolidated helper
473453
metadata = self._extract_field_metadata(field_info)
474-
extensions = self._build_comparison_extensions(metadata, format="json_schema")
454+
extensions = self._build_comparison_extensions(metadata, output_format="json_schema")
475455
property_schema.update(extensions)
476456

477457
# Add Pydantic field params
@@ -502,7 +482,7 @@ def field_to_stickler_config(self, field_type: Type, field_info: FieldInfo) -> D
502482

503483
# Extract metadata and build extensions using consolidated helper
504484
metadata = self._extract_field_metadata(field_info)
505-
extensions = self._build_comparison_extensions(metadata, format="stickler_config")
485+
extensions = self._build_comparison_extensions(metadata, output_format="stickler_config")
506486
field_config.update(extensions)
507487

508488
# Add Pydantic field params
@@ -521,23 +501,23 @@ def field_to_stickler_config(self, field_type: Type, field_info: FieldInfo) -> D
521501
def _build_comparison_extensions(
522502
self,
523503
metadata: Dict[str, Any],
524-
format: str = "json_schema"
504+
output_format: str = "json_schema"
525505
) -> Dict[str, Any]:
526506
"""Build comparison extensions in specified format.
527507
528508
Consolidates duplicate logic from field_to_property() and field_to_stickler_config().
529509
530510
Args:
531511
metadata: Extracted field metadata from _extract_field_metadata()
532-
format: Output format - "json_schema" or "stickler_config"
512+
output_format: Output format - "json_schema" or "stickler_config"
533513
534514
Returns:
535515
Dictionary with comparison extensions in the specified format
536516
"""
537517
extensions = {}
538-
if format not in ("json_schema", "stickler_config"):
539-
raise ValueError(f"Unsupported format: {format!r}. Use 'json_schema' or 'stickler_config'.")
540-
prefix = "x-aws-stickler-" if format == "json_schema" else ""
518+
if output_format not in ("json_schema", "stickler_config"):
519+
raise ValueError(f"Unsupported format: {output_format!r}. Use 'json_schema' or 'stickler_config'.")
520+
prefix = "x-aws-stickler-" if output_format == "json_schema" else ""
541521

542522
# Export comparator class name and configuration
543523
if metadata.get("comparator"):
@@ -546,7 +526,7 @@ def _build_comparison_extensions(
546526

547527
# Export comparator configuration (e.g., tolerance, case_sensitive)
548528
if hasattr(comparator, "config") and comparator.config:
549-
config_key = f"{prefix}comparator-config" if format == "json_schema" else "comparator_config"
529+
config_key = f"{prefix}comparator-config" if output_format == "json_schema" else "comparator_config"
550530
extensions[config_key] = comparator.config
551531

552532
# Export comparison parameters
@@ -555,7 +535,7 @@ def _build_comparison_extensions(
555535
if "weight" in metadata:
556536
extensions[f"{prefix}weight"] = metadata["weight"]
557537
if metadata.get("clip_under_threshold") is not None:
558-
clip_key = f"{prefix}clip-under-threshold" if format == "json_schema" else "clip_under_threshold"
538+
clip_key = f"{prefix}clip-under-threshold" if output_format == "json_schema" else "clip_under_threshold"
559539
extensions[clip_key] = metadata["clip_under_threshold"]
560540
if metadata.get("aggregate") is not None:
561541
extensions[f"{prefix}aggregate"] = metadata["aggregate"]

src/stickler/structured_object_evaluator/models/structured_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,7 @@ def to_json_schema(cls) -> Dict[str, Any]:
12481248
# Extract and add stickler extensions from field metadata
12491249
metadata = converter._extract_field_metadata(field_info)
12501250
extensions = converter._build_comparison_extensions(
1251-
metadata, format="json_schema"
1251+
metadata, output_format="json_schema"
12521252
)
12531253
property_schema.update(extensions)
12541254
else:
@@ -1377,10 +1377,11 @@ def to_stickler_config(cls) -> Dict[str, Any]:
13771377
if nested_config.get("match_threshold") is not None:
13781378
field_config["match_threshold"] = nested_config["match_threshold"]
13791379
else:
1380-
# Primitive list - use converter
1380+
# Primitive list - pass element type, then fix up type string
13811381
field_config = converter.field_to_stickler_config(
1382-
field_type, field_info
1382+
element_type, field_info
13831383
)
1384+
field_config["type"] = f"List[{element_type.__name__}]"
13841385
else:
13851386
# Primitive type - use converter for consistent formatting
13861387
field_config = converter.field_to_stickler_config(

tests/structured_object_evaluator/test_model_export.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ class ListModel(StructuredModel):
4040
products: List[SimpleProduct] = ComparableField()
4141

4242

43+
class PrimitiveListModel(StructuredModel):
44+
"""Model with List[primitive] for testing primitive list export."""
45+
tags: List[str] = ComparableField(threshold=0.8)
46+
scores: List[int] = ComparableField(threshold=0.9)
47+
48+
4349
def test_to_json_schema_basic():
4450
"""Test exporting simple model to JSON Schema format."""
4551
schema = SimpleProduct.to_json_schema()
@@ -167,6 +173,36 @@ def test_to_stickler_config_list():
167173
assert "price" in products_field["fields"]
168174

169175

176+
def test_to_json_schema_primitive_list():
177+
"""Test exporting model with List[str] and List[int] to JSON Schema."""
178+
schema = PrimitiveListModel.to_json_schema()
179+
180+
# List[str] should export as array with string items
181+
tags_prop = schema["properties"]["tags"]
182+
assert tags_prop["type"] == "array"
183+
assert tags_prop["items"]["type"] == "string"
184+
185+
# List[int] should export as array with integer items
186+
scores_prop = schema["properties"]["scores"]
187+
assert scores_prop["type"] == "array"
188+
assert scores_prop["items"]["type"] == "integer"
189+
190+
191+
def test_to_stickler_config_primitive_list():
192+
"""Test exporting model with List[str] and List[int] to Stickler config."""
193+
config = PrimitiveListModel.to_stickler_config()
194+
195+
# List[str] should export with list-aware type, not plain "str"
196+
tags_field = config["fields"]["tags"]
197+
assert tags_field["type"] == "List[str]"
198+
assert tags_field["threshold"] == 0.8
199+
200+
# List[int] should export with list-aware type, not plain "int"
201+
scores_field = config["fields"]["scores"]
202+
assert scores_field["type"] == "List[int]"
203+
assert scores_field["threshold"] == 0.9
204+
205+
170206
def test_export_preserves_metadata():
171207
"""Test that all comparison metadata is preserved in export."""
172208

0 commit comments

Comments
 (0)