@@ -66,13 +66,53 @@ def __init__(
6666 preload : bool = False ,
6767 ) -> None :
6868 super ().__init__ (inference_adapter , configuration , preload )
69- self ._check_io_number (1 , 1 )
69+ self ._check_io_number (1 , ( 1 , 4 ) )
7070 self .normalization_scale : float
7171 self .image_threshold : float
7272 self .pixel_threshold : float
7373 self .task : str
7474 self .labels : list [str ]
7575
def preprocess(self, inputs: np.ndarray) -> list[dict]:
    """Prepare a single image for an Anomalib model.

    Anomalib models typically expect inputs in the [0, 1] range as float32,
    so ``uint8`` images are cast to float32 and divided by 255, while any
    other dtype is only cast to float32.

    Args:
        inputs: Image array. For dynamic-shape models it must be
            three-dimensional, because its shape is unpacked as ``h, w, c``.

    Returns:
        A two-element list: the model feed ``{self.image_blob_name: image}``
        followed by a metadata dict holding ``"original_shape"`` and
        ``"resized_shape"``.

    Raises:
        ValueError: If the model is dynamic and ``inputs`` is not 3-D.
    """
    original_shape = inputs.shape

    if self._is_dynamic:
        # Dynamic models take the image at its own size; the reported
        # "resized" shape is just the input shape in (w, h, c) order.
        h, w, c = original_shape
        resized_shape = (w, h, c)
        passthrough = False
    else:
        resized_shape = (self.w, self.h, self.c)
        # Presumably the model graph does its own preprocessing when
        # embedded_processing is set, so only a batch axis is added below.
        passthrough = self.embedded_processing

    if passthrough:
        processed_image = inputs[None]
    else:
        # Convert to float32 and bring uint8 data into [0, 1]. InputTransform
        # is skipped on purpose (it might apply the wrong normalization).
        processed_image = inputs.astype(np.float32)
        if inputs.dtype == np.uint8:
            processed_image = processed_image / 255.0
        # NOTE(review): no spatial resize happens on this path, although for
        # fixed-shape models resized_shape reports (self.w, self.h, self.c).
        # Confirm that callers pass images already at the model input size.
        processed_image = self._change_layout(processed_image)

    return [
        {self.image_blob_name: processed_image},
        {
            "original_shape": original_shape,
            "resized_shape": resized_shape,
        },
    ]
115+
76116 def postprocess (self , outputs : dict [str , np .ndarray ], meta : dict [str , Any ]) -> AnomalyResult :
77117 """Post-processes the outputs and returns the results.
78118
@@ -87,48 +127,59 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
87127 pred_label : str | None = None
88128 pred_mask : np .ndarray | None = None
89129 pred_boxes : np .ndarray | None = None
90- predictions = outputs [next (iter (self .outputs ))]
91130
92- if len (predictions .shape ) == 1 :
93- pred_score = predictions
94- else :
95- anomaly_map = predictions .squeeze ()
96- pred_score = anomaly_map .reshape (- 1 ).max ()
131+ anomalib_keys = ['pred_score' , 'pred_label' , 'pred_mask' , 'anomaly_map' ]
132+ if not all (key in outputs for key in anomalib_keys ):
133+ predictions = outputs [next (iter (self .outputs ))]
97134
98- pred_label = self .labels [1 ] if pred_score > self .image_threshold else self .labels [0 ]
135+ if len (predictions .shape ) == 1 :
136+ npred_score = predictions
137+ else :
138+ anomaly_map = predictions .squeeze ()
139+ npred_score = anomaly_map .reshape (- 1 ).max ()
99140
100- assert anomaly_map is not None
101- pred_mask = (anomaly_map >= self .pixel_threshold ).astype (np .uint8 )
102- anomaly_map = self ._normalize (anomaly_map , self .pixel_threshold )
103- anomaly_map *= 255
104- anomaly_map = np .round (anomaly_map ).astype (np .uint8 )
105- pred_mask = cv2 .resize (
106- pred_mask ,
107- (meta ["original_shape" ][1 ], meta ["original_shape" ][0 ]),
108- )
141+ pred_label = self .labels [1 ] if npred_score > self .image_threshold else self .labels [0 ]
109142
110- # normalize
111- pred_score = self ._normalize (pred_score , self .image_threshold )
143+ assert anomaly_map is not None
144+ pred_mask = (anomaly_map >= self .pixel_threshold ).astype (np .uint8 )
145+ anomaly_map = self ._normalize (anomaly_map , self .pixel_threshold )
112146
113- if pred_label == self .labels [0 ]: # normal
114- pred_score = 1 - pred_score # Score of normal is 1 - score of anomaly
147+ # normalize
148+ npred_score = self ._normalize (npred_score , self .image_threshold )
149+
150+ if pred_label == self .labels [0 ]: # normal
151+ npred_score = 1 - npred_score # Score of normal is 1 - score of anomaly
152+ pred_score = npred_score .item ()
153+ else :
154+ pred_score = outputs ['pred_score' ].item ()
155+ pred_label = str (outputs ['pred_label' ].item ())
156+ anomaly_map = outputs ['anomaly_map' ].squeeze ()
157+ pred_mask = outputs ['pred_mask' ].squeeze ().astype (np .uint8 )
158+
159+ anomaly_map *= 255
160+ anomaly_map = np .round (anomaly_map ).astype (np .uint8 )
115161
116- # resize outputs
117162 if anomaly_map is not None :
118163 anomaly_map = cv2 .resize (
119164 anomaly_map ,
120165 (meta ["original_shape" ][1 ], meta ["original_shape" ][0 ]),
121166 )
122167
168+ pred_mask = cv2 .resize (
169+ pred_mask ,
170+ (meta ["original_shape" ][1 ], meta ["original_shape" ][0 ]),
171+ )
172+
123173 if self .task == "detection" :
124174 pred_boxes = self ._get_boxes (pred_mask )
125175
176+
126177 return AnomalyResult (
127178 anomaly_map = anomaly_map ,
128179 pred_boxes = pred_boxes ,
129180 pred_label = pred_label ,
130181 pred_mask = pred_mask ,
131- pred_score = pred_score . item () ,
182+ pred_score = pred_score ,
132183 )
133184
134185 @classmethod
0 commit comments