Code to reproduce:
from gemma import gm
import requests
from io import BytesIO
from PIL import Image
# Gemma 3n E2B model, tokenizer and instruction-tuned checkpoint
tokenizer = gm.text.Gemma3nTokenizer()
model = gm.nn.Gemma3n_E2B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3N_E2B_IT)
sampler = gm.text.ChatSampler(model=model, params=params, tokenizer=tokenizer)  # replaced by gm.text.Sampler below
# Download a test image and convert it to RGB
IMAGE_URL = "https://farm8.staticflickr.com/7149/6712974825_2bcfd17204_z.jpg"
response = requests.get(IMAGE_URL)
img = Image.open(BytesIO(response.content)).convert("RGB")
# Plain (non-chat) sampler with a single image placeholder in the prompt
sampler = gm.text.Sampler(model=model, params=params, tokenizer=tokenizer)
prompt = "<start_of_image>\nAnswer the following question in a single word or short phrase based on the image."
output = sampler.sample(prompt, images=[img])
print(output)
Error details:
TypeCheckError: argument "positions" (jax._src.interpreters.partial_eval.DynamicJaxprTracer) did not match any element in the union:
jaxtyping.Int[Array, '*B L']: is not an instance of jaxtyping.Int[Array, '*B L']
jaxtyping.Int[ndarray, '*B L']: is not an instance of jaxtyping.Int[ndarray, '*B L']
NoneType: is not an instance of NoneType
Function: Gemma3nTransformer.__call__ in /usr/local/lib/python3.12/dist-packages/gemma/gm/nn/gemma3n/_transformer.py:212
Inputs:
self = <class 'gemma.gm.nn.gemma3n._gemma3n.Gemma3n_E2B'>
tokens: Int['*B L'] = i32[1 253]
images: UInt8['*B N H W C'] = ui8[1 1 427 640 3]
positions: Int['*B L'] = i32[1 512]
positions_offset: Int['*B'] = None
cache: types.UnionType[dict[str, dict[str, jax.Array]], NoneType] = <class 'dict'>
attention_mask: Bool['*B L cache_length'] = bool_[1 512 512]
return_last_only: types.UnionType[bool, NoneType] = True
return_hidden_states: types.UnionType[bool, NoneType] = None
Inferred Dims:
{'L': 253, 'N': 1, 'H': 427, 'W': 640, 'C': 3, '*B': (1,)}
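The "is not an instance of jaxtyping.Int[Array, '*B L']" wording is misleading: jaxtyping appears to report a shape-consistency failure this way. The Inferred Dims show L bound to 253 (from tokens), while positions is i32[1 512] and attention_mask is bool_[1 512 512], so the mismatch on L looks like the actual cause. The hypothetical helper check_named_dims below is a simplified pure-Python illustration of that binding rule (a named dimension is bound on first use and must match afterwards); it is not jaxtyping's real implementation. The shapes are copied from the log above.

```python
# Simplified sketch of named-dimension binding (NOT jaxtyping's real code).
# check_named_dims is a hypothetical helper for illustration only.

def check_named_dims(annotations, shapes):
    """Bind each named dim on first use; collect later size mismatches."""
    bound = {}
    mismatches = []
    for arg, dims in annotations.items():
        shape = tuple(shapes[arg])
        split = len(shape) - len(dims)
        # Leading axes belong to the variadic "*B"; named dims are trailing.
        pairs = [("*B", shape[:split])] + list(zip(dims, shape[split:]))
        for name, size in pairs:
            if name not in bound:
                bound[name] = size  # first argument that uses the name wins
            elif bound[name] != size:
                mismatches.append((arg, name, bound[name], size))
    return bound, mismatches


# Annotations and shapes taken from the "Inputs" section of the error.
annotations = {
    "tokens": ("L",),
    "positions": ("L",),
    "attention_mask": ("L", "cache_length"),
}
shapes = {
    "tokens": (1, 253),
    "positions": (1, 512),
    "attention_mask": (1, 512, 512),
}

bound, mismatches = check_named_dims(annotations, shapes)
print(bound)       # {'*B': (1,), 'L': 253, 'cache_length': 512}
print(mismatches)  # [('positions', 'L', 253, 512), ('attention_mask', 'L', 253, 512)]
```

Because arguments are checked in order, positions is the first argument to disagree with L = 253, which matches the argument named in the TypeCheckError.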
This only occurs with the Gemma 3n models; Gemma3_4B, for instance, does not have this issue.
Occurs on gemma==3.3.0 (the latest version on PyPI).