Skip to content

Commit 86d2537

Browse files
committed
add anomalib support
1 parent 079e7b6 commit 86d2537

3 files changed

Lines changed: 124 additions & 31 deletions

File tree

src/model_api/models/anomaly.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,53 @@ def __init__(
6666
preload: bool = False,
6767
) -> None:
6868
super().__init__(inference_adapter, configuration, preload)
69-
self._check_io_number(1, 1)
69+
self._check_io_number(1, (1, 4))
7070
self.normalization_scale: float
7171
self.image_threshold: float
7272
self.pixel_threshold: float
7373
self.task: str
7474
self.labels: list[str]
7575

76+
def preprocess(self, inputs: np.ndarray) -> list[dict]:
    """Prepare an image for an Anomalib model.

    Unlike the generic image preprocessing, ``InputTransform`` is skipped on
    purpose: it might apply a normalization the Anomalib model does not
    expect. ``uint8`` images are scaled to the [0, 1] range as ``float32``;
    any other dtype is only cast to ``float32``.

    Args:
        inputs: input image. In the dynamic-shape case it must be a
            3-dimensional (height, width, channels) array.

    Returns:
        A two-element list:
        - a dict mapping ``self.image_blob_name`` to the preprocessed image
        - the input metadata (``original_shape`` and ``resized_shape``),
          which might be used in the ``postprocess`` method

    Raises:
        ValueError: if the model is dynamic and ``inputs`` is not
            3-dimensional.
    """
    original_shape = inputs.shape

    if self._is_dynamic:
        if inputs.ndim != 3:
            msg = f"Expected an HWC image with 3 dimensions, got shape {original_shape}"
            raise ValueError(msg)
        h, w, c = original_shape
        # Dynamic models take the image at its own size, so the "resized"
        # shape is just the input shape in (w, h, c) order.
        resized_shape = (w, h, c)
    else:
        resized_shape = (self.w, self.h, self.c)

    if not self._is_dynamic and self.embedded_processing:
        # Preprocessing is embedded into the model: only add the batch axis.
        processed_image = inputs[None]
    else:
        # Single conversion path shared by the dynamic and the fixed
        # (non-embedded) cases.
        processed_image = inputs.astype(np.float32)
        if inputs.dtype == np.uint8:
            processed_image /= 255.0
        # Apply the layout change but skip InputTransform (see docstring).
        processed_image = self._change_layout(processed_image)

    return [
        {self.image_blob_name: processed_image},
        {
            "original_shape": original_shape,
            "resized_shape": resized_shape,
        },
    ]
115+
76116
def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> AnomalyResult:
77117
"""Post-processes the outputs and returns the results.
78118
@@ -87,48 +127,59 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
87127
pred_label: str | None = None
88128
pred_mask: np.ndarray | None = None
89129
pred_boxes: np.ndarray | None = None
90-
predictions = outputs[next(iter(self.outputs))]
91130

92-
if len(predictions.shape) == 1:
93-
pred_score = predictions
94-
else:
95-
anomaly_map = predictions.squeeze()
96-
pred_score = anomaly_map.reshape(-1).max()
131+
anomalib_keys = ['pred_score', 'pred_label', 'pred_mask', 'anomaly_map']
132+
if not all(key in outputs for key in anomalib_keys):
133+
predictions = outputs[next(iter(self.outputs))]
97134

98-
pred_label = self.labels[1] if pred_score > self.image_threshold else self.labels[0]
135+
if len(predictions.shape) == 1:
136+
npred_score = predictions
137+
else:
138+
anomaly_map = predictions.squeeze()
139+
npred_score = anomaly_map.reshape(-1).max()
99140

100-
assert anomaly_map is not None
101-
pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
102-
anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
103-
anomaly_map *= 255
104-
anomaly_map = np.round(anomaly_map).astype(np.uint8)
105-
pred_mask = cv2.resize(
106-
pred_mask,
107-
(meta["original_shape"][1], meta["original_shape"][0]),
108-
)
141+
pred_label = self.labels[1] if npred_score > self.image_threshold else self.labels[0]
109142

110-
# normalize
111-
pred_score = self._normalize(pred_score, self.image_threshold)
143+
assert anomaly_map is not None
144+
pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
145+
anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
112146

113-
if pred_label == self.labels[0]: # normal
114-
pred_score = 1 - pred_score # Score of normal is 1 - score of anomaly
147+
# normalize
148+
npred_score = self._normalize(npred_score, self.image_threshold)
149+
150+
if pred_label == self.labels[0]: # normal
151+
npred_score = 1 - npred_score # Score of normal is 1 - score of anomaly
152+
pred_score=npred_score.item()
153+
else:
154+
pred_score = outputs['pred_score'].item()
155+
pred_label = str(outputs['pred_label'].item())
156+
anomaly_map = outputs['anomaly_map'].squeeze()
157+
pred_mask = outputs['pred_mask'].squeeze().astype(np.uint8)
158+
159+
anomaly_map *= 255
160+
anomaly_map = np.round(anomaly_map).astype(np.uint8)
115161

116-
# resize outputs
117162
if anomaly_map is not None:
118163
anomaly_map = cv2.resize(
119164
anomaly_map,
120165
(meta["original_shape"][1], meta["original_shape"][0]),
121166
)
122167

168+
pred_mask = cv2.resize(
169+
pred_mask,
170+
(meta["original_shape"][1], meta["original_shape"][0]),
171+
)
172+
123173
if self.task == "detection":
124174
pred_boxes = self._get_boxes(pred_mask)
125175

176+
126177
return AnomalyResult(
127178
anomaly_map=anomaly_map,
128179
pred_boxes=pred_boxes,
129180
pred_label=pred_label,
130181
pred_mask=pred_mask,
131-
pred_score=pred_score.item(),
182+
pred_score=pred_score,
132183
)
133184

134185
@classmethod

src/model_api/models/image_model.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
7373
self.n, self.c, self.h, self.w = self.inputs[self.image_blob_name].shape
7474
else:
7575
self.n, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape
76+
77+
self._is_dynamic = False
78+
if self.h == -1 or self.w == -1 or self.n == -1:
79+
self._is_dynamic = True
80+
if self.n == -1:
81+
self.n = 1
82+
7683
self.resize = RESIZE_TYPES[self.resize_type]
7784
self.input_transform = InputTransform(
7885
self.reverse_input_channels,
@@ -83,7 +90,7 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
8390
layout = self.inputs[self.image_blob_name].layout
8491
if self.embedded_processing:
8592
self.h, self.w = self.orig_height, self.orig_width
86-
else:
93+
elif not self._is_dynamic:
8794
inference_adapter.embed_preprocessing(
8895
layout=layout,
8996
resize_mode=self.resize_type,
@@ -213,11 +220,24 @@ def preprocess(self, inputs: np.ndarray) -> list[dict]:
213220
}
214221
- the input metadata, which might be used in `postprocess` method
215222
"""
223+
if self._is_dynamic:
224+
h, w, c = inputs.shape
225+
resized_shape = (w, h, c)
226+
processed_image = self.input_transform(inputs)
227+
processed_image = self._change_layout(processed_image)
228+
else:
229+
resized_shape = (self.w, self.h, self.c)
230+
if self.embedded_processing:
231+
processed_image = inputs[None]
232+
else:
233+
processed_image = self.input_transform(inputs)
234+
processed_image = self._change_layout(processed_image)
235+
216236
return [
217-
{self.image_blob_name: inputs[None]},
237+
{self.image_blob_name: processed_image},
218238
{
219239
"original_shape": inputs.shape,
220-
"resized_shape": (self.w, self.h, self.c),
240+
"resized_shape": resized_shape,
221241
},
222242
]
223243

@@ -230,9 +250,13 @@ def _change_layout(self, image: np.ndarray) -> np.ndarray:
230250
Returns:
231251
- the image with layout aligned with the model layout
232252
"""
253+
h, w, c = image.shape if self._is_dynamic else (self.h, self.w, self.c)
254+
255+
# For fixed models, use the predefined dimensions
233256
if self.nchw_layout:
234257
image = image.transpose((2, 0, 1)) # HWC->CHW
235-
image = image.reshape((1, self.c, self.h, self.w))
258+
image = image.reshape((1, c, h, w))
236259
else:
237-
image = image.reshape((1, self.h, self.w, self.c))
260+
image = image.reshape((1, h, w, c))
261+
238262
return image

src/model_api/models/model.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,30 @@ def create_model(
195195
cache_dir=cache_dir,
196196
)
197197
if model_type is None:
198-
model_type = inference_adapter.get_rt_info(
199-
["model_info", "model_type"],
200-
).astype(str)
198+
try:
199+
model_type = inference_adapter.get_rt_info(
200+
["model_info", "model_type"],
201+
).astype(str)
202+
except RuntimeError:
203+
model_type = cls.detect_model_type(inference_adapter)
201204
Model = cls.get_model_class(model_type)
202205
return Model(inference_adapter, configuration, preload)
203206

207+
@classmethod
def detect_model_type(cls, inference_adapter) -> str:
    """Guess the model type from the model's input and output layers.

    Used as a fallback when the model carries no ``model_type`` entry in
    its runtime info. Only the Anomalib export pattern is recognized: a
    single input and exactly the four outputs ``pred_score``,
    ``pred_label``, ``anomaly_map`` and ``pred_mask``.

    Args:
        inference_adapter: adapter exposing ``get_input_layers()`` and
            ``get_output_layers()``; the latter must return a mapping
            keyed by output layer name.

    Returns:
        ``"AnomalyDetection"`` when the Anomalib pattern matches,
        otherwise the sentinel ``"unknown"``.
    """
    anomalib_outputs = {"pred_score", "pred_label", "anomaly_map", "pred_mask"}

    input_layers = inference_adapter.get_input_layers()
    output_layers = inference_adapter.get_output_layers()

    # Set equality against four distinct names already implies there are
    # exactly four outputs, so no separate length check is needed.
    if len(input_layers) == 1 and set(output_layers.keys()) == anomalib_outputs:
        return "AnomalyDetection"

    return "unknown"
221+
204222
@classmethod
205223
def get_subclasses(cls) -> list[Any]:
206224
"""Retrieves all the subclasses of the model class given."""

0 commit comments

Comments
 (0)