11import json
22from dataclasses import dataclass
33from enum import Enum
4- from typing import Any , Dict , List , Optional , Sequence , Union
4+ from typing import Dict , List , Optional , Sequence , Union
55from nucleus .dataset_item import is_local_path
66
77from .constants import (
@@ -174,11 +174,23 @@ def to_payload(self) -> dict:
174174 }
175175
176176
177- # TODO: Add Generic type for 2D point
177+ @dataclass
178+ class Point :
179+ x : float
180+ y : float
181+
182+ @classmethod
183+ def from_json (cls , payload : Dict [str , float ]):
184+ return cls (payload [X_KEY ], payload [Y_KEY ])
185+
186+ def to_payload (self ) -> dict :
187+ return {X_KEY : self .x , Y_KEY : self .y }
188+
189+
178190@dataclass
179191class PolygonAnnotation (Annotation ):
180192 label : str
181- vertices : List [Any ]
193+ vertices : List [Point ]
182194 reference_id : Optional [str ] = None
183195 item_id : Optional [str ] = None
184196 annotation_id : Optional [str ] = None
@@ -187,28 +199,46 @@ class PolygonAnnotation(Annotation):
187199 def __post_init__ (self ):
188200 self ._check_ids ()
189201 self .metadata = self .metadata if self .metadata else {}
202+ if len (self .vertices ) > 0 :
203+ if not hasattr (self .vertices [0 ], X_KEY ) or not hasattr (
204+ self .vertices [0 ], "to_payload"
205+ ):
206+ try :
207+ self .vertices = [
208+ Point (x = vertex [X_KEY ], y = vertex [Y_KEY ])
209+ for vertex in self .vertices
210+ ]
211+ except KeyError as ke :
212+ raise ValueError (
213+ "Use a point object to pass in vertices. For example, vertices=[nucleus.Point(x=1, y=2)]"
214+ ) from ke
190215
191216 @classmethod
192217 def from_json (cls , payload : dict ):
193218 geometry = payload .get (GEOMETRY_KEY , {})
194219 return cls (
195220 label = payload .get (LABEL_KEY , 0 ),
196- vertices = geometry .get (VERTICES_KEY , []),
221+ vertices = [
222+ Point .from_json (_ ) for _ in geometry .get (VERTICES_KEY , [])
223+ ],
197224 reference_id = payload .get (REFERENCE_ID_KEY , None ),
198225 item_id = payload .get (DATASET_ITEM_ID_KEY , None ),
199226 annotation_id = payload .get (ANNOTATION_ID_KEY , None ),
200227 metadata = payload .get (METADATA_KEY , {}),
201228 )
202229
203230 def to_payload (self ) -> dict :
204- return {
231+ payload = {
205232 LABEL_KEY : self .label ,
206233 TYPE_KEY : POLYGON_TYPE ,
207- GEOMETRY_KEY : {VERTICES_KEY : self .vertices },
234+ GEOMETRY_KEY : {
235+ VERTICES_KEY : [_ .to_payload () for _ in self .vertices ]
236+ },
208237 REFERENCE_ID_KEY : self .reference_id ,
209238 ANNOTATION_ID_KEY : self .annotation_id ,
210239 METADATA_KEY : self .metadata ,
211240 }
241+ return payload
212242
213243
214244def check_all_annotation_paths_remote (
0 commit comments