Skip to content

Commit 92e1d7c

Browse files
committed
Fix image preprocessing using the original implementation as reference
1 parent bdc11dc commit 92e1d7c

3 files changed

Lines changed: 74 additions & 21 deletions

File tree

cells2table/models/PaddlePaddle/cell_detection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import Iterable, Iterator, Sequence
33

4+
import cv2
45
import numpy as np
56
from numpy.typing import NDArray
67

@@ -54,6 +55,17 @@ def __call__(
5455

5556
return result
5657

58+
def preprocess(self, input: Iterable[NDArray[np.uint8]]) -> list[NDArray[np.float32]]:
59+
blob = cv2.dnn.blobFromImages(
60+
input,
61+
scalefactor=1 / 255.0,
62+
size=self.input_shape,
63+
swapRB=False,
64+
crop=False,
65+
) # ty:ignore[no-matching-overload]
66+
67+
return list(blob)
68+
5769
@classmethod
5870
def postprocess(
5971
cls,

cells2table/models/PaddlePaddle/table_classification.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import Iterable, Sequence
33

4+
import cv2
45
import numpy as np
56
from numpy.typing import NDArray
67

@@ -26,9 +27,9 @@ def get_download_options(cls) -> DownloadOptions:
2627

2728
def __call__(self, input: Iterable[NDArray[np.uint8]]) -> list[ClassificationResult]:
2829
logger.debug("Started preprocessing")
29-
input = self.preprocess(input)
30+
images = self.preprocess(input)
3031

31-
input_dict = dict(zip(self.input_names, [input]))
32+
input_dict = dict(zip(self.input_names, [images]))
3233

3334
logger.debug("Done preprocessing")
3435
logger.debug("Started running the model")
@@ -44,6 +45,65 @@ def __call__(self, input: Iterable[NDArray[np.uint8]]) -> list[ClassificationRes
4445

4546
return result
4647

48+
def preprocess(self, input: Iterable[NDArray[np.uint8]]) -> list[NDArray[np.float32]]:
49+
"""PP-LCNet image preprocessing pipeline.
50+
51+
Args:
52+
input: iterable of HxWxC uint8 images (C=3, assumed RGB).
53+
54+
Output:
55+
list of CxHxW float32 tensors (BGR order), normalized with PP-LCNet mean/std.
56+
"""
57+
resize_short = 256 # shorter edge after resize
58+
crop_size = 224 # center crop size
59+
mean = np.asarray([0.406, 0.456, 0.485], dtype=np.float32) # RGB mean
60+
std = np.asarray([0.225, 0.224, 0.229], dtype=np.float32) # RGB std
61+
rescale_factor = 1.0 / 255.0 # uint8 -> [0,1]
62+
63+
out: list[NDArray[np.float32]] = []
64+
65+
for img in input:
66+
# Validate and coerce to expected dtype/layout (HWC, uint8, 3 channels)
67+
if img.ndim != 3 or img.shape[2] != 3:
68+
raise ValueError(f"Expected HxWx3 image, got shape={img.shape}")
69+
if img.dtype != np.uint8:
70+
raise ValueError(f"Expected uint8 image, got dtype={img.dtype}")
71+
72+
h, w = img.shape[:2]
73+
74+
# Resize while preserving aspect ratio using the shorter edge as reference
75+
scale = resize_short / float(min(h, w))
76+
new_h = int(round(h * scale))
77+
new_w = int(round(w * scale))
78+
79+
# Perform the resize (OpenCV expects size as (width, height))
80+
resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
81+
82+
# Center-crop to crop_size x crop_size (assumes resized dims are >= crop_size)
83+
if new_h < crop_size or new_w < crop_size:
84+
raise ValueError(
85+
f"Resized image too small for center crop: resized={new_h}x{new_w}, crop={crop_size}"
86+
)
87+
top = (new_h - crop_size) // 2
88+
left = (new_w - crop_size) // 2
89+
cropped = resized[top : top + crop_size, left : left + crop_size, :]
90+
91+
# Convert to float32 and rescale to [0,1]
92+
x = cropped.astype(np.float32) * rescale_factor
93+
94+
# Normalize per channel in RGB space: (x - mean) / std
95+
x = (x - mean) / std
96+
97+
# Convert RGB -> BGR
98+
x = x[..., ::-1]
99+
100+
# Convert HWC -> CHW
101+
x = np.transpose(x, (2, 0, 1)).astype(np.float32, copy=False)
102+
103+
out.append(x)
104+
105+
return out
106+
47107
@classmethod
48108
def postprocess(cls, pred: Sequence[Sequence[float]]) -> list[ClassificationResult]:
49109
return [ClassificationResult(cls.classes[np.argmax(p)], max(p)) for p in pred]

cells2table/models/runtimes/onnx.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
11
from abc import ABC, abstractmethod
22
from pathlib import Path
3-
from typing import Iterable
43

5-
import cv2
6-
import numpy as np
74
import onnxruntime as ort
8-
from numpy.typing import NDArray
95

106
from cells2table.models.tasks.base import BaseModel
117

128

139
class OnnxModel(BaseModel, ABC):
1410
"""Base interface for ONNX models."""
1511

16-
scale = 1 / 255.0
17-
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
18-
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
19-
2012
@classmethod
2113
@abstractmethod
2214
def get_onnx_path(self) -> str:
@@ -49,14 +41,3 @@ def input_names(self):
4941
@property
5042
def output_names(self):
5143
return [v.name for v in self.session.get_outputs()]
52-
53-
def preprocess(self, input: Iterable[NDArray[np.uint8]]) -> list[NDArray[np.uint8]]:
54-
output = []
55-
56-
for img in input:
57-
img = cv2.resize(img, dsize=self.input_shape, interpolation=cv2.INTER_LANCZOS4)
58-
img = (img.astype(np.float32) * self.scale - self.mean) / self.std # Normalize
59-
img = img.transpose(2, 0, 1) # HWC to CHW
60-
output.append(img)
61-
62-
return output

0 commit comments

Comments
 (0)