From b756b6ddf881b76754a492b23d24f14ae926a656 Mon Sep 17 00:00:00 2001 From: adithya32 <163162210+KumarADITHYA123@users.noreply.github.com> Date: Thu, 12 Feb 2026 02:33:36 +0530 Subject: [PATCH] Remove TensorFlow dependency from multimodal image preprocessing Replace tf.image.decode_jpeg with PIL-based JPEG encoding/decoding. This eliminates the TensorFlow dependency from the multimodal module while maintaining identical functionality. - Remove: import tensorflow as tf - Add: import io (for BytesIO) - Replace: TF JPEG ops with PIL equivalent - Maintain: Same output shape and dtype Addresses TODO in gemma/multimodal/image.py:73 --- gemma/multimodal/image.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/gemma/multimodal/image.py b/gemma/multimodal/image.py index 91f1ff4a..dca88fd6 100644 --- a/gemma/multimodal/image.py +++ b/gemma/multimodal/image.py @@ -18,12 +18,12 @@ from collections.abc import Sequence import einops from etils import epath +import io import jax from jax import numpy as jnp from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -70,10 +70,12 @@ def pre_process_image( The pre-processed image. """ # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. - image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) + # Perform JPEG encode/decode to normalize image (matches original behavior) + pil_image = Image.fromarray(np.uint8(image)) + buffer = io.BytesIO() + pil_image.save(buffer, format="JPEG") + buffer.seek(0) + image = jnp.asarray(np.array(Image.open(buffer).convert("RGB"))) image = jax.image.resize( image, shape=(image_height, image_width, 3),