|
| 1 | +""" |
| 2 | +Image preprocessing for the multimodal data pipeline. |
| 3 | +
|
| 4 | +Converts a PIL image (or HWC ``numpy`` array) into the |
| 5 | +``(n_patches, patch_dim)`` patch tensor + ``(n_patches,)`` validity mask |
| 6 | +that the vision tower expects. |
| 7 | +
|
| 8 | +Patch flatten convention is **channel-first** (``[c, kh, kw]``) — matching the |
| 9 | +HuggingFace ``Conv2d`` patch embedding semantics our parity tests verified |
| 10 | +against ``openai/clip-vit-large-patch14-336`` and |
| 11 | +``google/siglip-so400m-patch14-384``. |
| 12 | +""" |
| 13 | + |
| 14 | +from dataclasses import dataclass |
| 15 | +from typing import Tuple |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +from ...config import Config, StrEnum |
| 20 | + |
| 21 | +__all__ = [ |
| 22 | + "NormalizeStyle", |
| 23 | + "ImagePreprocessorConfig", |
| 24 | + "ImagePreprocessor", |
| 25 | +] |
| 26 | + |
| 27 | + |
| 28 | +# --------------------------------------------------------------------------- |
| 29 | +# Standard mean / std constants (RGB, [0, 1] scale) |
| 30 | +# --------------------------------------------------------------------------- |
| 31 | + |
| 32 | +OPENAI_CLIP_MEAN: Tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073) |
| 33 | +OPENAI_CLIP_STD: Tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711) |
| 34 | + |
| 35 | + |
| 36 | +class NormalizeStyle(StrEnum): |
| 37 | + """How to normalize pixel values before patchification.""" |
| 38 | + |
| 39 | + siglip = "siglip" |
| 40 | + """SigLIP / SigLIP2: scale ``[0, 1] → [-1, 1]`` via ``image * 2 - 1``.""" |
| 41 | + |
| 42 | + openai = "openai" |
| 43 | + """OpenAI CLIP: ``(image - OPENAI_CLIP_MEAN) / OPENAI_CLIP_STD``.""" |
| 44 | + |
| 45 | + |
| 46 | +@dataclass |
| 47 | +class ImagePreprocessorConfig(Config): |
| 48 | + """ |
| 49 | + Configuration for :class:`ImagePreprocessor`. |
| 50 | + """ |
| 51 | + |
| 52 | + patch_size: int = 14 |
| 53 | + """Pixel size of each square patch. Must divide every crop dimension.""" |
| 54 | + |
| 55 | + normalize: NormalizeStyle = NormalizeStyle.siglip |
| 56 | + """Normalization style; pick the one matching your vision encoder.""" |
| 57 | + |
| 58 | + pad_value: float = 0.0 |
| 59 | + """Pixel value (in the ``[0, 1]`` scale, pre-normalize) used to pad the |
| 60 | + image when its aspect ratio differs from the target crop.""" |
| 61 | + |
| 62 | + def build(self) -> "ImagePreprocessor": |
| 63 | + """Instantiate an :class:`ImagePreprocessor` for this config.""" |
| 64 | + return ImagePreprocessor(self) |
| 65 | + |
| 66 | + |
| 67 | +class ImagePreprocessor: |
| 68 | + """Resize, normalize, and patchify a single image.""" |
| 69 | + |
| 70 | + def __init__(self, cfg: ImagePreprocessorConfig): |
| 71 | + self.cfg = cfg |
| 72 | + |
| 73 | + # ------------------------------------------------------------------ |
| 74 | + # Resize + pad |
| 75 | + # ------------------------------------------------------------------ |
| 76 | + |
| 77 | + @staticmethod |
| 78 | + def _to_float_hwc(image) -> np.ndarray: |
| 79 | + """Coerce input to a float32 HWC array in [0, 1].""" |
| 80 | + if isinstance(image, np.ndarray): |
| 81 | + arr = image |
| 82 | + else: |
| 83 | + # PIL.Image or anything with .convert / np.asarray support |
| 84 | + arr = np.asarray(image.convert("RGB") if hasattr(image, "convert") else image) |
| 85 | + if arr.dtype != np.float32: |
| 86 | + arr = arr.astype(np.float32) |
| 87 | + if arr.max() > 1.5: # likely uint8 range |
| 88 | + arr = arr / 255.0 |
| 89 | + if arr.ndim == 2: |
| 90 | + arr = np.stack([arr] * 3, axis=-1) |
| 91 | + if arr.shape[-1] != 3: |
| 92 | + raise ValueError(f"expected RGB image, got shape {arr.shape}") |
| 93 | + return arr |
| 94 | + |
| 95 | + def resize_and_pad( |
| 96 | + self, |
| 97 | + image, |
| 98 | + target_size: Tuple[int, int], |
| 99 | + ) -> Tuple[np.ndarray, np.ndarray]: |
| 100 | + """Resize aspect-preserving to fit ``target_size``, then pad. |
| 101 | +
|
| 102 | + :param image: PIL image, HWC ``np.uint8`` / ``np.float32`` array. |
| 103 | + :param target_size: ``(target_h, target_w)``. |
| 104 | + :returns: ``(image_arr, mask_arr)``: |
| 105 | + - ``image_arr``: ``(target_h, target_w, 3)`` ``float32`` in |
| 106 | + ``[0, 1]`` (pre-normalize). Padded regions equal ``pad_value``. |
| 107 | + - ``mask_arr``: ``(target_h, target_w)`` ``float32`` with ``1.0`` |
| 108 | + where the image content lives and ``0.0`` where it was padded. |
| 109 | + """ |
| 110 | + cfg = self.cfg |
| 111 | + arr = self._to_float_hwc(image) |
| 112 | + src_h, src_w, _ = arr.shape |
| 113 | + tgt_h, tgt_w = target_size |
| 114 | + |
| 115 | + # Scale to fit (aspect-preserving). Use the dimension that's the tighter limit. |
| 116 | + scale = min(tgt_h / src_h, tgt_w / src_w) |
| 117 | + new_h = max(1, int(round(src_h * scale))) |
| 118 | + new_w = max(1, int(round(src_w * scale))) |
| 119 | + |
| 120 | + # Bilinear resize via numpy. Keeping it dependency-free. |
| 121 | + resized = _bilinear_resize(arr, new_h, new_w) |
| 122 | + |
| 123 | + # Pad to target. |
| 124 | + out = np.full((tgt_h, tgt_w, 3), cfg.pad_value, dtype=np.float32) |
| 125 | + mask = np.zeros((tgt_h, tgt_w), dtype=np.float32) |
| 126 | + pad_top = (tgt_h - new_h) // 2 |
| 127 | + pad_left = (tgt_w - new_w) // 2 |
| 128 | + out[pad_top : pad_top + new_h, pad_left : pad_left + new_w] = resized |
| 129 | + mask[pad_top : pad_top + new_h, pad_left : pad_left + new_w] = 1.0 |
| 130 | + return out, mask |
| 131 | + |
| 132 | + # ------------------------------------------------------------------ |
| 133 | + # Normalize |
| 134 | + # ------------------------------------------------------------------ |
| 135 | + |
| 136 | + def normalize(self, image: np.ndarray) -> np.ndarray: |
| 137 | + """Apply the configured normalization (in-place safe).""" |
| 138 | + cfg = self.cfg |
| 139 | + if cfg.normalize == NormalizeStyle.siglip: |
| 140 | + return image * 2.0 - 1.0 |
| 141 | + elif cfg.normalize == NormalizeStyle.openai: |
| 142 | + mean = np.asarray(OPENAI_CLIP_MEAN, dtype=np.float32)[None, None, :] |
| 143 | + std = np.asarray(OPENAI_CLIP_STD, dtype=np.float32)[None, None, :] |
| 144 | + return (image - mean) / std |
| 145 | + else: |
| 146 | + raise NotImplementedError(f"unsupported normalize style: {cfg.normalize}") |
| 147 | + |
| 148 | + # ------------------------------------------------------------------ |
| 149 | + # Patchify (channel-first flatten) |
| 150 | + # ------------------------------------------------------------------ |
| 151 | + |
| 152 | + def patchify(self, image: np.ndarray) -> np.ndarray: |
| 153 | + """Reshape ``(H, W, 3)`` to ``(n_patches, 3 * p * p)`` with C-first flatten. |
| 154 | +
|
| 155 | + The output order for each patch matches the natural flatten of a |
| 156 | + HuggingFace ``Conv2d(kernel=p, stride=p)`` weight reshaped via |
| 157 | + ``.reshape(D, -1)``: index ``i = c * p * p + kh * p + kw`` selects |
| 158 | + pixel ``(c, kh, kw)`` within the patch. |
| 159 | + """ |
| 160 | + p = self.cfg.patch_size |
| 161 | + h, w, c = image.shape |
| 162 | + if h % p != 0 or w % p != 0: |
| 163 | + raise ValueError(f"image size ({h}, {w}) is not divisible by patch_size {p}") |
| 164 | + # (H, W, C) → (h_patches, p, w_patches, p, C) |
| 165 | + x = image.reshape(h // p, p, w // p, p, c) |
| 166 | + # → (h_patches, w_patches, C, p, p) so flatten is C-first per patch. |
| 167 | + x = x.transpose(0, 2, 4, 1, 3) |
| 168 | + return x.reshape((h // p) * (w // p), c * p * p).astype(np.float32, copy=False) |
| 169 | + |
| 170 | + def patchify_mask(self, mask: np.ndarray) -> np.ndarray: |
| 171 | + """Per-pixel mask ``(H, W)`` → per-patch coverage ``(n_patches,)``. |
| 172 | +
|
| 173 | + Returns the mean of the per-pixel mask within each patch, so partially |
| 174 | + padded patches receive a fractional weight in ``(0, 1)``. |
| 175 | + """ |
| 176 | + p = self.cfg.patch_size |
| 177 | + h, w = mask.shape |
| 178 | + if h % p != 0 or w % p != 0: |
| 179 | + raise ValueError(f"mask size ({h}, {w}) is not divisible by patch_size {p}") |
| 180 | + x = ( |
| 181 | + mask.reshape(h // p, p, w // p, p) |
| 182 | + .transpose(0, 2, 1, 3) |
| 183 | + .reshape((h // p) * (w // p), p * p) |
| 184 | + ) |
| 185 | + return x.mean(axis=-1).astype(np.float32, copy=False) |
| 186 | + |
| 187 | + # ------------------------------------------------------------------ |
| 188 | + # Convenience: full pipeline for a single crop |
| 189 | + # ------------------------------------------------------------------ |
| 190 | + |
| 191 | + def preprocess( |
| 192 | + self, |
| 193 | + image, |
| 194 | + target_size: Tuple[int, int], |
| 195 | + ) -> Tuple[np.ndarray, np.ndarray]: |
| 196 | + """Resize, normalize, and patchify a single image. |
| 197 | +
|
| 198 | + :returns: ``(patches, mask)`` with shapes ``(n_patches, 3 * p * p)`` |
| 199 | + and ``(n_patches,)`` respectively. |
| 200 | + """ |
| 201 | + image_arr, mask_arr = self.resize_and_pad(image, target_size) |
| 202 | + image_arr = self.normalize(image_arr) |
| 203 | + return self.patchify(image_arr), self.patchify_mask(mask_arr) |
| 204 | + |
| 205 | + |
| 206 | +# --------------------------------------------------------------------------- |
| 207 | +# Dependency-free bilinear resize |
| 208 | +# --------------------------------------------------------------------------- |
| 209 | + |
| 210 | + |
| 211 | +def _bilinear_resize(image: np.ndarray, new_h: int, new_w: int) -> np.ndarray: |
| 212 | + """Bilinear resize of an HWC float32 array. |
| 213 | +
|
| 214 | + Implemented in pure numpy so the preprocessor has no torchvision or PIL |
| 215 | + dependency at the resize step. Aligns corners the same way ``PIL.Image |
| 216 | + .resize(..., BILINEAR)`` does (half-pixel offsets). |
| 217 | + """ |
| 218 | + src_h, src_w, c = image.shape |
| 219 | + if src_h == new_h and src_w == new_w: |
| 220 | + return image.astype(np.float32, copy=False) |
| 221 | + |
| 222 | + # Compute source coordinates of each output pixel (half-pixel sampling). |
| 223 | + y = (np.arange(new_h, dtype=np.float32) + 0.5) * (src_h / new_h) - 0.5 |
| 224 | + x = (np.arange(new_w, dtype=np.float32) + 0.5) * (src_w / new_w) - 0.5 |
| 225 | + y0 = np.clip(np.floor(y).astype(np.int64), 0, src_h - 1) |
| 226 | + x0 = np.clip(np.floor(x).astype(np.int64), 0, src_w - 1) |
| 227 | + y1 = np.clip(y0 + 1, 0, src_h - 1) |
| 228 | + x1 = np.clip(x0 + 1, 0, src_w - 1) |
| 229 | + wy = np.clip(y - y0, 0.0, 1.0) |
| 230 | + wx = np.clip(x - x0, 0.0, 1.0) |
| 231 | + |
| 232 | + # Gather and blend the four neighbors. |
| 233 | + top_left = image[y0[:, None], x0[None, :], :] |
| 234 | + top_right = image[y0[:, None], x1[None, :], :] |
| 235 | + bot_left = image[y1[:, None], x0[None, :], :] |
| 236 | + bot_right = image[y1[:, None], x1[None, :], :] |
| 237 | + |
| 238 | + wy_2d = wy[:, None, None] |
| 239 | + wx_2d = wx[None, :, None] |
| 240 | + top = top_left * (1 - wx_2d) + top_right * wx_2d |
| 241 | + bot = bot_left * (1 - wx_2d) + bot_right * wx_2d |
| 242 | + return (top * (1 - wy_2d) + bot * wy_2d).astype(np.float32) |
0 commit comments