2121 ANNOTATIONS_KEY ,
2222)
2323
24+ from dataclasses import dataclass
25+
2426
2527class Annotation :
28+ def _check_ids (self ):
29+ if bool (self .reference_id ) == bool (self .item_id ):
30+ raise Exception (
31+ "You must specify either a reference_id or an item_id for an annotation."
32+ )
33+
2634 @classmethod
2735 def from_json (cls , payload : dict ):
2836 if payload .get (TYPE_KEY , None ) == BOX_TYPE :
@@ -33,16 +41,11 @@ def from_json(cls, payload: dict):
3341 return SegmentationAnnotation .from_json (payload )
3442
3543
44+ @dataclass
3645class Segment :
37- def __init__ (
38- self , label : str , index : int , metadata : Optional [dict ] = None
39- ):
40- self .label = label
41- self .index = index
42- self .metadata = metadata
43-
44- def __str__ (self ):
45- return str (self .to_payload ())
46+ label : str
47+ index : int
48+ metadata : Optional [dict ] = None
4649
4750 @classmethod
4851 def from_json (cls , payload : dict ):
@@ -62,35 +65,25 @@ def to_payload(self) -> dict:
6265 return payload
6366
6467
68+ @dataclass
6569class SegmentationAnnotation (Annotation ):
66- def __init__ (
67- self ,
68- mask_url : str ,
69- annotations : List [Segment ],
70- reference_id : Optional [str ] = None ,
71- item_id : Optional [str ] = None ,
72- annotation_id : Optional [str ] = None ,
73- ):
74- super ().__init__ ()
75- if not mask_url :
70+ mask_url : str
71+ annotations : List [Segment ]
72+ reference_id : Optional [str ] = None
73+ item_id : Optional [str ] = None
74+ annotation_id : Optional [str ] = None
75+
76+ def __post_init__ (self ):
77+ if not self .mask_url :
7678 raise Exception ("You must specify a mask_url." )
77- if bool (reference_id ) == bool (item_id ):
78- raise Exception (
79- "You must specify either a reference_id or an item_id for an annotation."
80- )
81- self .mask_url = mask_url
82- self .annotations = annotations
83- self .reference_id = reference_id
84- self .item_id = item_id
85- self .annotation_id = annotation_id
86-
87- def __str__ (self ):
88- return str (self .to_payload ())
79+ self ._check_ids ()
8980
9081 @classmethod
9182 def from_json (cls , payload : dict ):
83+ if MASK_URL_KEY not in payload :
84+ raise ValueError (f"Missing { MASK_URL_KEY } in json" )
9285 return cls (
93- mask_url = payload . get ( MASK_URL_KEY ) ,
86+ mask_url = payload [ MASK_URL_KEY ] ,
9487 annotations = [
9588 Segment .from_json (ann )
9689 for ann in payload .get (ANNOTATIONS_KEY , [])
@@ -118,35 +111,21 @@ class AnnotationTypes(Enum):
118111 POLYGON = POLYGON_TYPE
119112
120113
121- # TODO: Add base annotation class to reduce repeated code here
114+ @ dataclass
122115class BoxAnnotation (Annotation ):
123- # pylint: disable=too-many-instance-attributes
124- def __init__ (
125- self ,
126- label : str ,
127- x : Union [float , int ],
128- y : Union [float , int ],
129- width : Union [float , int ],
130- height : Union [float , int ],
131- reference_id : Optional [str ] = None ,
132- item_id : Optional [str ] = None ,
133- annotation_id : Optional [str ] = None ,
134- metadata : Optional [Dict ] = None ,
135- ):
136- super ().__init__ ()
137- if bool (reference_id ) == bool (item_id ):
138- raise Exception (
139- "You must specify either a reference_id or an item_id for an annotation."
140- )
141- self .label = label
142- self .x = x
143- self .y = y
144- self .width = width
145- self .height = height
146- self .reference_id = reference_id
147- self .item_id = item_id
148- self .annotation_id = annotation_id
149- self .metadata = metadata if metadata else {}
116+ label : str
117+ x : Union [float , int ]
118+ y : Union [float , int ]
119+ width : Union [float , int ]
120+ height : Union [float , int ]
121+ reference_id : Optional [str ] = None
122+ item_id : Optional [str ] = None
123+ annotation_id : Optional [str ] = None
124+ metadata : Optional [Dict ] = None
125+
126+ def __post_init__ (self ):
127+ self ._check_ids ()
128+ self .metadata = self .metadata if self .metadata else {}
150129
151130 @classmethod
152131 def from_json (cls , payload : dict ):
@@ -178,32 +157,20 @@ def to_payload(self) -> dict:
178157 METADATA_KEY : self .metadata ,
179158 }
180159
181- def __str__ (self ):
182- return str (self .to_payload ())
183-
184160
185161# TODO: Add Generic type for 2D point
162+ @dataclass
186163class PolygonAnnotation (Annotation ):
187- def __init__ (
188- self ,
189- label : str ,
190- vertices : List [Any ],
191- reference_id : Optional [str ] = None ,
192- item_id : Optional [str ] = None ,
193- annotation_id : Optional [str ] = None ,
194- metadata : Optional [Dict ] = None ,
195- ):
196- super ().__init__ ()
197- if bool (reference_id ) == bool (item_id ):
198- raise Exception (
199- "You must specify either a reference_id or an item_id for an annotation."
200- )
201- self .label = label
202- self .vertices = vertices
203- self .reference_id = reference_id
204- self .item_id = item_id
205- self .annotation_id = annotation_id
206- self .metadata = metadata if metadata else {}
164+ label : str
165+ vertices : List [Any ]
166+ reference_id : Optional [str ] = None
167+ item_id : Optional [str ] = None
168+ annotation_id : Optional [str ] = None
169+ metadata : Optional [Dict ] = None
170+
171+ def __post_init__ (self ):
172+ self ._check_ids ()
173+ self .metadata = self .metadata if self .metadata else {}
207174
208175 @classmethod
209176 def from_json (cls , payload : dict ):
@@ -226,6 +193,3 @@ def to_payload(self) -> dict:
226193 ANNOTATION_ID_KEY : self .annotation_id ,
227194 METADATA_KEY : self .metadata ,
228195 }
229-
230- def __str__ (self ):
231- return str (self .to_payload ())
0 commit comments