Skip to content

Commit d02e676

Browse files
committed
ViT with matching image embedding
1 parent 9ea5211 commit d02e676

8 files changed

Lines changed: 438 additions & 10 deletions

File tree

MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,9 @@ temperature_tuning: False
698698

699699
# Multimodal flags
700700
use_multimodal: False
701+
freeze_vision_encoder_params: True
702+
dtype_mm: "float32" # Data type for multimodal model's vision encoder
703+
remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
701704
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
702705
image_path: "" # Local image path used for decoding
703706

MaxText/convert_gemma3_chkpt.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ def nest_params(params: Params) -> Params:
4545
return nested_params
4646

4747

48+
def rename_nested_keys(data, old_key, new_key):
  """Return a copy of a nested dictionary with every `old_key` renamed to `new_key`.

  The input is left untouched. Dictionaries are rebuilt at every nesting level.
  Lists are rebuilt too, and dictionaries found inside them are processed at any
  list nesting depth (e.g. a list of lists of dicts). Every other value is
  passed through unchanged.

  Args:
    data (dict): The nested dictionary to process.
    old_key (str): The key to find and rename.
    new_key (str): The new name for the key.

  Returns:
    dict: A new dictionary with the specified keys renamed.
  """

  def _convert(value):
    if isinstance(value, dict):
      return rename_nested_keys(value, old_key, new_key)
    if isinstance(value, list):
      # Recurse through lists so dictionaries buried in nested lists are renamed as well.
      return [_convert(item) for item in value]
    return value

  return {(new_key if key == old_key else key): _convert(value) for key, value in data.items()}
68+
69+
4870
def main(raw_args=None) -> None:
4971
parser = argparse.ArgumentParser()
5072
parser.add_argument("--base_model_path", type=str, required=True)
@@ -76,7 +98,29 @@ def main(raw_args=None) -> None:
7698
"decoder_norm": {"scale": params["transformer"]["final_norm"]["scale"] + 1},
7799
},
78100
"token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)},
101+
"vision_encoder": {
102+
"Gemma3VisionEncoderLayer_0": {
103+
"embedding": {
104+
"bias": params["SigLiPFromPatches_0"]["siglip_encoder"]["embedding"]["bias"],
105+
"kernel": params["SigLiPFromPatches_0"]["siglip_encoder"]["embedding"]["kernel"],
106+
},
107+
"pos_embedding": params["SigLiPFromPatches_0"]["siglip_encoder"]["pos_embedding"],
108+
"Transformer": params["SigLiPFromPatches_0"]["siglip_encoder"]["Transformer"],
109+
"VisionEmbedder_0": {
110+
"mm_input_projection": params["transformer"]["embedder"]["mm_input_projection"],
111+
"mm_soft_embedding_norm": {
112+
"scale": params["transformer"]["embedder"]["mm_soft_embedding_norm"]["scale"] + 1
113+
},
114+
},
115+
}
116+
},
79117
}
118+
# Rename MlpBlock_0 to MlpBlockViT_0 in vision encoder
119+
# This is because the gemma3 model has MlpBlock in the vision encoder,
120+
# which has the same name as the MlpBlock in the MaxText decoder but different structure.
121+
# Hence, we need to rename it to avoid confusion.
122+
vision_encoder_weights = rename_nested_keys(jax_weights["vision_encoder"], "MlpBlock_0", "MlpBlockViT_0")
123+
jax_weights["vision_encoder"] = vision_encoder_weights
80124
self_attention = dict(
81125
{
82126
"query": {"kernel": []},
@@ -191,6 +235,7 @@ def astype_fn(x):
191235
if checkpoint_manager is not None:
192236
if save_checkpoint(checkpoint_manager, 0, state_new):
193237
max_logging.log("saved a checkpoint at step 0")
238+
max_logging.log(f"Checkpoint saved to: {args.maxtext_model_path}")
194239
# Upon preemption, exit when and only when all ongoing saves are complete.
195240
if checkpoint_manager.reached_preemption(0):
196241
checkpoint_manager.wait_until_finished()

MaxText/layers/gemma3.py

Lines changed: 276 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,26 +82,295 @@ def get_query_pre_attn_scalar(config) -> float:
8282
raise ValueError(f"Unsupported model name: {config.model_name}")
8383

8484

85+
def _posemb_sincos_2d(
    h: int,
    w: int,
    *,
    width: int,
    temperature: float = 10_000.0,
    dtype: jnp.dtype = jnp.float32,
):
  """Builds a fixed 2D sine/cosine position embedding (follows the MoCo v3 logic).

  Args:
    h: Number of grid rows.
    w: Number of grid columns.
    width: Embedding size; must be a multiple of 4 because it is split evenly
      across sin(x), cos(x), sin(y) and cos(y).
    temperature: Base of the geometric frequency progression.
    dtype: dtype of the returned array.

  Returns:
    Array of shape (1, h * w, width), laid out along the last axis as
    [sin(x), cos(x), sin(y), cos(y)].
  """
  y, x = jnp.mgrid[:h, :w]

  assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
  num_freqs = width // 4
  # Guard the denominator: with a single frequency (width == 4) dividing by
  # `num_freqs - 1` is 0 / 0 and would fill the whole embedding with NaNs.
  omega = jnp.arange(num_freqs) / max(num_freqs - 1, 1)
  omega = 1.0 / (temperature**omega)
  y = jnp.einsum("m,d->md", y.flatten(), omega)
  x = jnp.einsum("m,d->md", x.flatten(), omega)
  pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
  # Add a leading axis of size 1 so the result broadcasts over the batch.
  return jnp.asarray(pe, dtype)[None, :, :]
103+
104+
105+
class MlpBlockViT(nn.Module):
  """Feed-forward (MLP) sub-block of a ViT transformer layer.

  Attributes:
    block_id: Index of the enclosing encoder block (not read inside this module).
    dtype_mm: dtype forwarded to both Dense layers.
    mlp_dim: Hidden width; when unset (or 0) it falls back to 4x the input feature size.
    dropout: Dropout rate applied after the activation.
  """

  block_id: int
  dtype_mm: str
  mlp_dim: int | None = None  # Defaults to 4x input dim
  dropout: float = 0.0

  @nn.compact
  def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """Applies Dense -> GELU -> Dropout -> Dense and returns a tensor with the input's feature size."""
    in_features = x.shape[-1]
    hidden_features = self.mlp_dim or 4 * in_features
    # Both projections share the same initializers and dtype.
    dense_kwargs = dict(
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=self.dtype_mm,
    )

    hidden = nn.Dense(features=hidden_features, **dense_kwargs)(x)
    hidden = nn.gelu(hidden)
    hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic)
    return nn.Dense(features=in_features, **dense_kwargs)(hidden)
131+
132+
133+
class Encoder1DBlock(nn.Module):
  """Single pre-norm transformer encoder block (MHSA + MLP), each with a residual connection.

  Attributes:
    block_id: Index of this block; forwarded to the MLP sub-block.
    dtype_mm: dtype forwarded to the attention and MLP layers.
    mlp_dim: Hidden width of the MLP sub-block (see MlpBlockViT for the fallback).
    num_heads: Number of attention heads.
    dropout: Dropout rate applied after attention and after the MLP.
  """

  block_id: int
  dtype_mm: str
  mlp_dim: int | None = None  # Defaults to 4x input dim
  num_heads: int = 12
  dropout: float = 0.0

  @nn.compact
  def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """Runs the block on `x` and returns a single array (not a tuple).

    Args:
      x: Input activations; constrained to the
        (activation_batch, activation_length, activation_embed) logical axes.
      deterministic: When True, dropout is disabled.

    Returns:
      The transformed activations.
    """
    activation_axes = ("activation_batch", "activation_length", "activation_embed")

    x = nn.with_logical_constraint(x, activation_axes)

    # Self-attention branch: the normalized input serves as both query and key/value.
    y = nn.LayerNorm()(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform(),
        deterministic=deterministic,
        dtype=self.dtype_mm,
    )(y, y)
    y = nn.with_logical_constraint(y, activation_axes)
    y = nn.Dropout(rate=self.dropout)(y, deterministic)
    x = x + y

    # Feed-forward branch.
    y = nn.LayerNorm()(x)
    y = MlpBlockViT(
        block_id=self.block_id,
        mlp_dim=self.mlp_dim,
        dropout=self.dropout,
        dtype_mm=self.dtype_mm,
    )(y, deterministic)
    y = nn.with_logical_constraint(y, activation_axes)
    y = nn.Dropout(rate=self.dropout)(y, deterministic)
    x = x + y

    x = nn.with_logical_constraint(x, activation_axes)
    return x
169+
170+
171+
class Encoder(nn.Module):
  """Stack of `depth` Encoder1DBlock layers followed by a final LayerNorm.

  Attributes:
    depth: Number of encoder blocks.
    dtype_mm: dtype forwarded to every encoder block.
    remat_policy: Name looked up on `jax.checkpoint_policies` (scan path only).
    mlp_dim: Hidden width of each block's MLP.
    num_heads: Number of attention heads per block.
    dropout: Dropout rate forwarded to every block.
    scan: When True, blocks are rematerialized and stacked with `nn.scan`;
      otherwise they are instantiated one by one in a Python loop.
  """

  depth: int
  dtype_mm: str
  remat_policy: str
  mlp_dim: int | None = None  # Defaults to 4x input dim
  num_heads: int = 12
  dropout: float = 0.0
  scan: bool = False

  @nn.compact
  def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """Applies all encoder blocks to `x`, then the `encoder_norm` LayerNorm."""
    # Hyper-parameters shared by every block, on both code paths.
    block_kwargs = dict(
        dtype_mm=self.dtype_mm,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
    )

    if self.scan:
      # NOTE(review): an unknown `remat_policy` name silently resolves to
      # `policy=None` here -- confirm that this fallback is intended.
      remat_block = nn.remat(
          Encoder1DBlock,
          prevent_cse=False,
          static_argnums=(2,),  # 0=self, 2=deterministic
          policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
      )
      scanned_block = nn.scan(
          remat_block,
          variable_axes={"params": 0},
          split_rngs={"params": True, "dropout": True},
          in_axes=nn.broadcast,
          length=self.depth,
      )
      # NOTE(review): Encoder1DBlock returns a single array, whereas a scanned
      # body is normally expected to return a (carry, output) pair -- verify
      # that this path works before relying on scan=True.
      x = scanned_block(block_id=0, name="encoderblock", **block_kwargs)(x, deterministic)
    else:
      # Unrolled path: one explicitly named block per layer.
      for layer_idx in range(self.depth):
        encoder_block = Encoder1DBlock(
            block_id=layer_idx,
            name=f"encoderblock_{layer_idx}",
            **block_kwargs,
        )
        x = encoder_block(x, deterministic)

    return nn.LayerNorm(name="encoder_norm")(x)
221+
222+
223+
class Einsum(nn.Module):
  """Holds one weight tensor and contracts an input against it via `jnp.einsum`.

  Attributes:
    shape: Shape of the weight tensor.
    weight_name: Name under which the weight is registered as a parameter.
    initializer: Initializer used to create the weight.
    dtype: dtype handed to the initializer; None is passed through unchanged.
  """

  shape: tuple[int, ...]
  weight_name: str = "w"
  initializer: nn.initializers.Initializer = nn.initializers.normal()
  dtype: jnp.dtype | None = None

  @nn.compact
  def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
    """Evaluates `eqn` with `x` as the first operand and the weight as the second."""
    weight = self.param(self.weight_name, self.initializer, self.shape, self.dtype)
    return jnp.einsum(eqn, x, weight)
240+
241+
242+
class VisionEmbedder(nn.Module):
  """Projects image embeddings to the embedding space of the text encoder.

  Attributes:
    embed_dim: Output feature size of the projection.
    vision_proj_dim: Input feature size of the projection. The norm and
      projection sub-modules are only created when this is set.
  """

  embed_dim: int
  vision_proj_dim: int | None = None

  def setup(self):
    # NOTE(review): when `vision_proj_dim` is unset nothing is created here, so
    # `encode_vision` cannot find its sub-modules -- confirm callers always set it.
    if not self.vision_proj_dim:
      return
    self.mm_soft_embedding_norm = RMSNorm()
    self.mm_input_projection = Einsum((self.vision_proj_dim, self.embed_dim))

  def encode_vision(self, x: jax.Array) -> jax.Array:
    """Normalizes `x`, then maps its last axis from `vision_proj_dim` to `embed_dim`."""
    normed = self.mm_soft_embedding_norm(x)
    return self.mm_input_projection("...tm,md->...td", normed)

  def __call__(self, x: jax.Array) -> jax.Array:
    """Alias for `encode_vision`."""
    return self.encode_vision(x)
260+
261+
262+
class VisionExit(nn.Module):
  """The vision exit layer.

  Possibly downsample the soft tokens to a required output length.

  Attributes:
    output_length: The embed will be spatially avg-pooled to this output length.
  """

  output_length: int = 256

  def __call__(self, x):
    """Average-pools the token axis of `x` down to `output_length` tokens.

    Args:
      x: Array indexed as (batch, tokens, embed). Unless the token count already
        equals `output_length`, it must be a perfect square whose side is a
        multiple of the square root of `output_length`.

    Returns:
      `x` itself when no pooling is needed, otherwise an array of shape
      (batch, output_length, embed).
    """
    cur_length = x.shape[1]
    if cur_length == self.output_length:
      return x

    # Tokens are treated as a square grid so they can be pooled spatially.
    cur_width = int(cur_length**0.5)
    assert cur_width**2 == cur_length
    output_width = int(self.output_length**0.5)
    assert output_width**2 == self.output_length, f"Cannot pool {x.shape=} to {self.output_length=}!"

    batch_size = x.shape[0]
    embed_dim = x.shape[-1]
    x = jnp.reshape(x, (batch_size, cur_width, cur_width, embed_dim))

    assert not cur_width % output_width, f"{cur_width=} {output_width=}"
    window = cur_width // output_width
    window_shape = (window, window)
    # Non-overlapping windows: the stride equals the window size.
    x = nn.avg_pool(x, window_shape=window_shape, strides=window_shape)

    batch_size, height, width, embed_dim = x.shape
    return jnp.reshape(x, (batch_size, height * width, embed_dim))
290+
291+
85292
class Gemma3VisionEncoderLayer(nn.Module):
  """Gemma3 vision tower: patch embedding, ViT encoder, token pooling and projection.

  Attributes:
    config: Model config; this module reads `scan_layers`, `remat_policy_for_vit`,
      `dtype_mm`, `emb_dim` and `freeze_vision_encoder_params` from it.
    patch_size: Kernel size and stride of the patch-embedding convolution.
    width: Feature size of the patch embeddings and of the projection input.
    mlp_dim: Hidden width of the encoder MLP blocks.
    depth: Number of encoder blocks.
    num_heads: Number of attention heads per encoder block.
    posemb: Position-embedding type, "learn" or "sincos2d".
    dropout: Dropout rate.
  """

  config: Config
  patch_size: tuple[int, int] = (14, 14)
  width: int = 1152
  mlp_dim: int | None = 4304  # When None, the MLP block falls back to 4x its input dim.
  depth: int = 27
  num_heads: int = 16
  posemb: str = "learn"  # Can also be "sincos2d"
  dropout: float = 0.0
  # The encoder's remat policy is taken from `config.remat_policy_for_vit` in `__call__`.

  def _get_posemb(
      self,
      typ: str,
      *,
      seqshape: tuple[int, int],
      width: int,
      name: str,
      dtype: jnp.dtype = jnp.float32,
  ):
    """Returns the position embedding.

    Args:
      typ: "learn" for a trainable parameter, "sincos2d" for a fixed embedding.
      seqshape: (rows, cols) of the patch grid.
      width: Feature size of the embedding.
      name: Parameter name (used by the "learn" variant only).
      dtype: dtype of the embedding.

    Raises:
      ValueError: If `typ` is neither "learn" nor "sincos2d".
    """
    if typ == "learn":
      shape_product = seqshape[0] * seqshape[1]
      return self.param(
          name,
          nn.initializers.normal(stddev=1 / (width**0.5)),
          (1, shape_product, width),
          dtype,
      )
    elif typ == "sincos2d":
      return _posemb_sincos_2d(*seqshape, width=width, dtype=dtype)
    else:
      raise ValueError(f"Unknown posemb type: {typ}")

  @nn.compact
  def __call__(self, inputs, deterministic, train=False):
    """ViT model that transforms image inputs to image embeddings.

    Args:
      inputs: jnp.array shaped [B, N, H, W, C], e.g. [4, 1, 896, 896, 3]
      deterministic: Forwarded to the transformer encoder; True disables its dropout.
      train: Controls only the dropout applied right after the position
        embedding (active when True).

    Returns:
      jnp.array for image embeddings, shaped [B, N, P, D], e.g. [4, 1, 256, 2560]
    """
    cfg = self.config
    b, n, h, w, c = inputs.shape
    # Fold the per-example image axis into the batch axis.
    x = jnp.reshape(inputs, [b * n, h, w, c])

    # Gemma3 uses conv2d with stride 14 and kernel size 14 to extract patches.
    # The module attributes are used here (their defaults carry those values)
    # so the convolution stays in sync with the `width` given to VisionEmbedder.
    x = nn.Conv(
        features=self.width,
        kernel_size=self.patch_size,
        strides=self.patch_size,
        padding="VALID",
        name="embedding",
    )(x)
    bn, h, w, c = x.shape
    x = jnp.reshape(x, [bn, h * w, c])

    # Add posemb before adding extra token.
    x = x + self._get_posemb(
        self.posemb,
        seqshape=(h, w),
        width=c,
        name="pos_embedding",
        dtype=x.dtype,
    )

    # NOTE(review): this dropout is gated by `train`, while the encoder below is
    # gated by `deterministic` -- confirm that callers keep the two flags in sync.
    x = nn.Dropout(rate=self.dropout)(x, not train)

    # Transformer encoder to extract image features.
    x = Encoder(
        depth=self.depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        scan=cfg.scan_layers,
        remat_policy=cfg.remat_policy_for_vit,
        dtype_mm=cfg.dtype_mm,
        name="Transformer",
    )(x, deterministic=deterministic)

    # Gemma3 uses a vision exit layer to downsample the soft tokens to a required output length.
    x = VisionExit(output_length=256)(x)
    bn, num_tokens, c = x.shape
    # Restore the separate batch and per-example image axes.
    x = jnp.reshape(x, [b, n, num_tokens, c])

    # VisionEmbedder is a projection layer that projects the image embeddings to align with text embeddings emb_dim.
    x = VisionEmbedder(embed_dim=cfg.emb_dim, vision_proj_dim=self.width)(x)
    if cfg.freeze_vision_encoder_params:
      # Block gradients from flowing back into the vision encoder.
      x = jax.lax.stop_gradient(x)
    return x
106375

107376

MaxText/layers/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,8 @@ def get_vision_encoder_layers(self):
623623
raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")
624624

625625
@nn.compact
626-
def __call__(self, input_images):
627-
embeddings = self.vision_encoder_layer[0](config=self.config)(input_images)
626+
def __call__(self, input_images, deterministic=False):
627+
embeddings = self.vision_encoder_layer[0](config=self.config)(input_images, deterministic=deterministic)
628628
return embeddings
629629

630630

@@ -685,7 +685,7 @@ def __call__(
685685

686686
bidirectional_mask = None
687687
if self.config.use_multimodal and encoder_images is not None:
688-
image_embeddings = self.vision_encoder(input_images=encoder_images)
688+
image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
689689
# TODO(hengtaoguo, aireen): merge image_embeddings with decoder_input_tokens.
690690

691691
if self.config.decoder_block == DecoderBlockType.GEMMA3:

0 commit comments

Comments
 (0)