From e5faa4504bf41a2a1d19aab1c3c31c5404cf385c Mon Sep 17 00:00:00 2001 From: Sheshank Singh <148331907+Sheshank-singh@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:05:39 +0530 Subject: [PATCH] fix: remove TensorFlow dependency from image processing Replace TensorFlow's JPEG encode/decode operations with a pure JAX/PIL implementation. This change: - Removes the unnecessary TensorFlow dependency from gemma/multimodal/image.py - Improves consistency by using only JAX/NumPy for ML operations - Uses PIL (already imported) for JPEG encoding/decoding instead - Maintains identical functionality and behavior Implementation details: - Clips the image to the uint8 range (0-255) for JPEG compatibility - Uses PIL to encode the image to a JPEG bytes buffer - Decodes the JPEG bytes back to an RGB array for standardization - Converts back to a float32 JAX array for subsequent processing This closes the TODO comment that requested TensorFlow dependency removal. The PIL library is part of the standard Python ecosystem and is already used elsewhere in the file for image loading. Files Modified: - gemma/multimodal/image.py --- gemma/multimodal/image.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/gemma/multimodal/image.py b/gemma/multimodal/image.py index ce4182e0..3c338e16 100644 --- a/gemma/multimodal/image.py +++ b/gemma/multimodal/image.py @@ -16,6 +16,7 @@ from __future__ import annotations from collections.abc import Sequence +import io import einops from etils import epath import jax @@ -23,7 +24,6 @@ from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -69,11 +69,17 @@ def pre_process_image( Returns: The pre-processed image. """ - # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. 
- image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) + # Normalize image to uint8 range for JPEG encoding + image_uint8 = jnp.clip(image, 0, 255).astype(jnp.uint8) + + # Encode and decode with JPEG via PIL for standardization + pil_image = Image.fromarray(np.array(image_uint8), mode='RGB') + jpeg_buffer = io.BytesIO() + pil_image.save(jpeg_buffer, format='JPEG') + jpeg_buffer.seek(0) + image = np.array(Image.open(jpeg_buffer).convert('RGB')) + image = jnp.asarray(image, dtype=jnp.float32) + image = jax.image.resize( image, shape=(image_height, image_width, 3),