Skip to content

Commit 73a9ac5

Browse files
author
The gemma Authors
committed
optional PLE token input
PiperOrigin-RevId: 906871699
1 parent ae84d95 commit 73a9ac5

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

gemma/gm/nn/gemma4/_modules.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,20 @@ def encode_audio(self, x: jax.Array) -> jax.Array:
166166
x = self.audio_soft_embedding_norm(x)
167167
return x
168168

169-
def encode_per_layer_input(self, x: jax.Array, t: jax.Array) -> jax.Array:
169+
def encode_per_layer_input(
170+
self,
171+
x: jax.Array,
172+
t: jax.Array,
173+
ignore_ple_tokens: bool = False,
174+
) -> jax.Array:
170175
"""Encodes the input tokens.
171176
172177
Args:
173178
x: Input shape [seq_len, embed_dim] or [batch_size, seq_len, embed_dim].
174179
t: Input tokens of shape [seq_len] or [batch_size, seq_len], where each
175180
token is an integer in [0, vocab_size).
181+
ignore_ple_tokens: If True, the tokens are not used to compute the per
182+
layer input embeddings.
176183
177184
Returns:
178185
Encoded input of shape [seq_len, num_layers, per_layer_input_dim] or
@@ -184,6 +191,8 @@ def encode_per_layer_input(self, x: jax.Array, t: jax.Array) -> jax.Array:
184191
)
185192
x = self.per_layer_model_projection('...td,dnp->...tnp', x)
186193
x = self.per_layer_projection_norm(x)
194+
if ignore_ple_tokens:
195+
return x
187196
y = self.per_layer_input_embedding_table[(t,)]
188197
y *= jnp.sqrt(self.per_layer_input_dim).astype(y.dtype)
189198
return (x + y) * jax.lax.rsqrt(2.0).astype(x.dtype)

gemma/gm/nn/gemma4/_transformer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def _encode_and_get_inputs(
421421
attention_mask=None,
422422
positions=None,
423423
audio_soft_token_counts=None,
424+
ignore_ple_tokens: bool = False,
424425
) -> _Inputs:
425426
"""Encode the text tokens, eventually including the vision embeddings."""
426427
if images is not None or audio is not None:
@@ -467,7 +468,9 @@ def _encode_and_get_inputs(
467468
)
468469

469470
if self.config.per_layer_input_dim:
470-
per_layer_inputs = self.embedder.encode_per_layer_input(x, tokens)
471+
per_layer_inputs = self.embedder.encode_per_layer_input(
472+
x, tokens, ignore_ple_tokens=ignore_ple_tokens
473+
)
471474
else:
472475
per_layer_inputs = None
473476

@@ -501,7 +504,7 @@ def _encode_and_get_inputs(
501504

502505
if self.config.per_layer_input_dim:
503506
per_layer_inputs = self.embedder.encode_per_layer_input(
504-
x, inputs.tokens_with_mm
507+
x, inputs.tokens_with_mm, ignore_ple_tokens=ignore_ple_tokens
505508
)
506509
else:
507510
per_layer_inputs = None

0 commit comments

Comments
 (0)