Skip to content

Commit d9eb817

Browse files
fix: UTC-539: Handle Labels split format in prediction validations (#755)
Co-authored-by: robot-ci-heartex <robot-ci-heartex@users.noreply.github.com>
1 parent 425049d commit d9eb817

3 files changed

Lines changed: 324 additions & 9 deletions

File tree

src/label_studio_sdk/label_interface/control_tags.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def _validate_value_labels(self, value):
298298
return self._validate_labels(value.get(self._label_attr_name))
299299
return False
300300

301-
def validate_value(self, value: dict) -> bool:
301+
def validate_value(self, value: dict, context: Optional[dict] = None) -> bool:
302302
"""
303303
Given "value" from [annotation result format](https://labelstud.io/guide/task_format),
304304
validate if it's a valid value for this control tag.
@@ -311,6 +311,11 @@ def validate_value(self, value: dict) -> bool:
311311
```python
312312
RectangleTag(name="rect", to_name=["img"], tag="rectangle", attr={}).validate_value({"x": 10, "y": 10, "width": 10, "height": 10, "rotation": 10})
313313
```
314+
context : dict, optional
315+
Additional context about sibling regions. May contain
316+
``result`` (the full result list) and ``region`` (the current
317+
region dict). Subclasses can inspect this to adjust
318+
validation behaviour.
314319
315320
Returns:
316321
--------
@@ -567,6 +572,79 @@ class LabelsTag(ControlTag):
567572
_label_attr_name: str = "labels"
568573
_value_class: Type[LabelsValue] = LabelsValue
569574

575+
def _get_split_partner_value_class(self, context: dict) -> Optional[Type[BaseModel]]:
576+
"""Return the geometry value class if this region is a split-format labels companion.
577+
578+
In split format a geometry result (rectangle, polygon, …) and a
579+
separate labels result share the same ``id``. When detected, this
580+
method returns the geometry partner's pydantic value class so the
581+
labels value can be validated against the expected geometry fields.
582+
583+
Returns ``None`` when the region is not part of a split-format pair.
584+
"""
585+
# Lazy import avoided: the mapping references value classes defined
586+
# later in this file, but they are available at call time.
587+
geometry_type_to_value_class: Dict[str, Type[BaseModel]] = {
588+
'rectangle': RectangleValue,
589+
'polygon': PolygonValue,
590+
'ellipse': EllipseValue,
591+
'keypoint': KeyPointValue,
592+
'brush': BrushValue,
593+
}
594+
595+
region = context.get('region', {})
596+
result = context.get('result', [])
597+
598+
# Only check for split-format pairs if the current region is a labels region
599+
if (region.get('type') or '').lower() != 'labels':
600+
return None
601+
602+
# Check if the current region has an id
603+
region_id = region.get('id')
604+
if region_id is None:
605+
return None
606+
607+
# Iterate through all regions to find a matching geometry partner
608+
for r in result:
609+
if not isinstance(r, dict):
610+
continue
611+
r_type = (r.get('type') or '').lower()
612+
# Checks if sibling region has the same id and is a geometry type
613+
if r.get('id') == region_id and r_type in geometry_type_to_value_class:
614+
# Return the geometry value class to validate the attributes of the current region
615+
return geometry_type_to_value_class[r_type]
616+
617+
return None
618+
619+
def validate_value(self, value: dict, context: Optional[dict] = None) -> bool:
620+
"""Validate value, selecting the pydantic model based on format.
621+
622+
When the ``context`` indicates this region is the labels companion
623+
of a split-format pair the partner's geometry value class (e.g.
624+
:class:`RectangleValue`, :class:`PolygonValue`) is used to
625+
validate the geometry fields. The ``labels`` subset is validated
626+
separately via :meth:`_validate_value_labels`.
627+
628+
Otherwise the strict :class:`LabelsValue` model is used
629+
(``start`` and ``end`` required alongside ``labels``).
630+
"""
631+
if isinstance(value, dict) and self.name in value and isinstance(value[self.name], dict):
632+
value = value[self.name]
633+
634+
if not self._validate_value_labels(value):
635+
return False
636+
637+
geometry_value_class = (
638+
self._get_split_partner_value_class(context) if context else None
639+
)
640+
# Use the geometry partner's value class for split format,
641+
# otherwise the default text-span LabelsValue.
642+
value_class = geometry_value_class or self._value_class
643+
try:
644+
value_class(**value)
645+
return True
646+
except Exception:
647+
return False
570648

571649
def to_json_schema(self):
572650
"""
@@ -961,7 +1039,7 @@ class RankerTag(ControlTag):
9611039
tag: str = "Ranker"
9621040
_value_class: Type[RankerValue] = RankerValue
9631041

964-
def validate_value(self, value: dict) -> bool:
1042+
def validate_value(self, value: dict, **kwargs) -> bool:
9651043
"""
9661044
Accept only:
9671045
- {"ranker": {"<control_tag_name>": [str, ...]}}
@@ -1069,7 +1147,7 @@ def to_json_schema(self):
10691147
class RelationsTag(ControlTag):
10701148
""" """
10711149
tag: str = "Relations"
1072-
def validate_value(self, ) -> bool:
1150+
def validate_value(self, **kwargs) -> bool:
10731151
""" """
10741152
raise NotImplemented("""Should not be called directly, instead
10751153
use validate_relation() method found in LabelInterface class""")

src/label_studio_sdk/label_interface/interface.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,10 @@ def _validate_object_logic(self, obj) -> List[str]:
832832
continue
833833

834834
if region.get('type') != "relation":
835-
region_errors = self.validate_region(region, return_errors=True, region_index=i)
835+
region_errors = self.validate_region(
836+
region, return_errors=True, region_index=i,
837+
context={'result': result, 'region': region},
838+
)
836839
errors.extend(region_errors)
837840
if not region_errors: # Only add to regions if no errors
838841
regions.append(region)
@@ -886,12 +889,14 @@ def validate_prediction(self, prediction, return_errors=False):
886889
"""
887890
return self._validate_object(prediction, return_errors)
888891

889-
def _validate_region_logic(self, region, region_index=0) -> Tuple[bool, List[str]]:
892+
def _validate_region_logic(self, region, region_index=0, context=None) -> Tuple[bool, List[str]]:
890893
"""Helper method to perform region validation logic.
891894
892895
Args:
893896
region (dict): The region to be validated.
894897
region_index (int): Index of the region for error reporting.
898+
context (dict, optional): Additional context forwarded to
899+
``control.validate_value()``.
895900
896901
Returns:
897902
tuple: (is_valid, errors) where is_valid is bool and errors is list of strings.
@@ -934,7 +939,7 @@ def _validate_region_logic(self, region, region_index=0) -> Tuple[bool, List[str
934939

935940
# Validate the value using control's validate_value method
936941
try:
937-
if not control.validate_value(region["value"]):
942+
if not control.validate_value(region["value"], context=context):
938943
# Prefer a clearer message for rectangle geometry bounds
939944
tag_lower = getattr(control, 'tag', '').lower()
940945
if tag_lower in ('rectangle', 'rectanglelabels'):
@@ -985,7 +990,7 @@ def _get_valid_values_for_control(self, control):
985990
except Exception:
986991
return "unknown validation rules"
987992

988-
def validate_region(self, region, return_errors=False, region_index=0):
993+
def validate_region(self, region, return_errors=False, region_index=0, context=None):
989994
"""Validates a region from the annotation against the current
990995
configuration.
991996
@@ -999,12 +1004,16 @@ def validate_region(self, region, return_errors=False, region_index=0):
9991004
region (dict): The region to be validated.
10001005
return_errors (bool): If True, returns a list of error messages instead of boolean
10011006
region_index (int): Index of the region for error reporting (used when return_errors=True)
1007+
context (dict, optional): Additional context passed through to
1008+
``control.validate_value()``. May contain ``result`` (the
1009+
full result list) and ``region`` (the current region dict)
1010+
so that control tags can inspect sibling regions.
10021011
10031012
Returns:
10041013
Union[bool, List[str]]: If return_errors=False, returns True/False.
10051014
If return_errors=True, returns list of error messages.
10061015
"""
1007-
is_valid, errors = self._validate_region_logic(region, region_index)
1016+
is_valid, errors = self._validate_region_logic(region, region_index, context=context)
10081017

10091018
if return_errors:
10101019
return errors

0 commit comments

Comments
 (0)