Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions gemma/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from collections.abc import Sequence
import einops
from etils import epath
import io
import jax
from jax import numpy as jnp
from kauldron import typing
import numpy as np
from PIL import Image
import tensorflow as tf

_IMAGE_MEAN = (127.5,) * 3
_IMAGE_STD = (127.5,) * 3
Expand Down Expand Up @@ -70,10 +70,12 @@ def pre_process_image(
The pre-processed image.
"""
# all inputs are expected to have been jpeg compressed.
# TODO(eyvinec): we should remove tf dependency.
image = jnp.asarray(
tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3)
)
# Perform JPEG encode/decode to normalize image (matches original behavior)
pil_image = Image.fromarray(np.uint8(image))
buffer = io.BytesIO()
pil_image.save(buffer, format="JPEG")
buffer.seek(0)
image = jnp.asarray(np.array(Image.open(buffer).convert("RGB")))
image = jax.image.resize(
image,
shape=(image_height, image_width, 3),
Expand Down