diff --git a/gemma/multimodal/image.py b/gemma/multimodal/image.py index 91f1ff4a..dca88fd6 100644 --- a/gemma/multimodal/image.py +++ b/gemma/multimodal/image.py @@ -18,12 +18,12 @@ from collections.abc import Sequence import einops from etils import epath +import io import jax from jax import numpy as jnp from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -70,10 +70,12 @@ def pre_process_image( The pre-processed image. """ # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. - image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) + # Perform JPEG encode/decode to normalize image (matches original behavior) + pil_image = Image.fromarray(np.uint8(image)) + buffer = io.BytesIO() + pil_image.save(buffer, format="JPEG") + buffer.seek(0) + image = jnp.asarray(np.array(Image.open(buffer).convert("RGB"))) image = jax.image.resize( image, shape=(image_height, image_width, 3),