@@ -82,26 +82,295 @@ def get_query_pre_attn_scalar(config) -> float:
8282 raise ValueError (f"Unsupported model name: { config .model_name } " )
8383
8484
85+ def _posemb_sincos_2d (
86+ h : int ,
87+ w : int ,
88+ * ,
89+ width : int ,
90+ temperature : float = 10_000.0 ,
91+ dtype : jnp .dtype = jnp .float32 ,
92+ ):
93+ """Follows the MoCo v3 logic."""
94+ y , x = jnp .mgrid [:h , :w ]
95+
96+ assert width % 4 == 0 , "Width must be mult of 4 for sincos posemb"
97+ omega = jnp .arange (width // 4 ) / (width // 4 - 1 )
98+ omega = 1.0 / (temperature ** omega )
99+ y = jnp .einsum ("m,d->md" , y .flatten (), omega )
100+ x = jnp .einsum ("m,d->md" , x .flatten (), omega )
101+ pe = jnp .concatenate ([jnp .sin (x ), jnp .cos (x ), jnp .sin (y ), jnp .cos (y )], axis = 1 )
102+ return jnp .asarray (pe , dtype )[None , :, :]
103+
104+
105+ class MlpBlockViT (nn .Module ):
106+ """Transformer MLP / feed-forward block."""
107+
108+ block_id : int
109+ dtype_mm : str
110+ mlp_dim : int | None = None # Defaults to 4x input dim
111+ dropout : float = 0.0
112+
113+ @nn .compact
114+ def __call__ (self , x : jax .Array , deterministic : bool = True ) -> jax .Array :
115+ """Applies Transformer MlpBlock module."""
116+ inits = dict (
117+ kernel_init = nn .initializers .xavier_uniform (),
118+ bias_init = nn .initializers .normal (stddev = 1e-6 ),
119+ )
120+
121+ d = x .shape [- 1 ]
122+ x = nn .Dense (features = self .mlp_dim or 4 * d , dtype = self .dtype_mm , ** inits )(x )
123+ x = nn .gelu (x )
124+ x = nn .Dropout (rate = self .dropout )(x , deterministic )
125+ x = nn .Dense (
126+ features = d ,
127+ dtype = self .dtype_mm ,
128+ ** inits ,
129+ )(x )
130+ return x
131+
132+
133+ class Encoder1DBlock (nn .Module ):
134+ """Single transformer encoder block (MHSA + MLP)."""
135+
136+ block_id : int
137+ dtype_mm : str
138+ mlp_dim : int | None = None # Defaults to 4x input dim
139+ num_heads : int = 12
140+ dropout : float = 0.0
141+
142+ @nn .compact
143+ def __call__ (self , x : jax .Array , deterministic : bool = True ) -> tuple [jax .Array , dict [str , jax .Array ]]:
144+ x = nn .with_logical_constraint (x , ("activation_batch" , "activation_length" , "activation_embed" ))
145+ y = nn .LayerNorm ()(x )
146+
147+ y = nn .MultiHeadDotProductAttention (
148+ num_heads = self .num_heads ,
149+ kernel_init = nn .initializers .xavier_uniform (),
150+ deterministic = deterministic ,
151+ dtype = self .dtype_mm ,
152+ )(y , y )
153+ y = nn .with_logical_constraint (y , ("activation_batch" , "activation_length" , "activation_embed" ))
154+ y = nn .Dropout (rate = self .dropout )(y , deterministic )
155+ x = x + y
156+
157+ y = nn .LayerNorm ()(x )
158+ y = MlpBlockViT (
159+ block_id = self .block_id ,
160+ mlp_dim = self .mlp_dim ,
161+ dropout = self .dropout ,
162+ dtype_mm = self .dtype_mm ,
163+ )(y , deterministic )
164+ y = nn .with_logical_constraint (y , ("activation_batch" , "activation_length" , "activation_embed" ))
165+ y = nn .Dropout (rate = self .dropout )(y , deterministic )
166+ x = x + y
167+ x = nn .with_logical_constraint (x , ("activation_batch" , "activation_length" , "activation_embed" ))
168+ return x
169+
170+
171+ class Encoder (nn .Module ):
172+ """Transformer Model Encoder for sequence to sequence translation."""
173+
174+ depth : int
175+ dtype_mm : str
176+ remat_policy : str
177+ mlp_dim : int | None = None # Defaults to 4x input dim
178+ num_heads : int = 12
179+ dropout : float = 0.0
180+ scan : bool = False
181+
182+ @nn .compact
183+ def __call__ (self , x : jax .Array , deterministic : bool = True ) -> jax .Array :
184+ if self .scan :
185+ block = nn .remat (
186+ Encoder1DBlock ,
187+ prevent_cse = False ,
188+ static_argnums = (2 ,), # 0=self, 2=deterministic
189+ policy = getattr (jax .checkpoint_policies , self .remat_policy , None ),
190+ )
191+ x = nn .scan (
192+ block ,
193+ variable_axes = {"params" : 0 },
194+ split_rngs = {"params" : True , "dropout" : True },
195+ in_axes = nn .broadcast ,
196+ length = self .depth ,
197+ )(
198+ block_id = 0 ,
199+ name = "encoderblock" ,
200+ dtype_mm = self .dtype_mm ,
201+ mlp_dim = self .mlp_dim ,
202+ num_heads = self .num_heads ,
203+ dropout = self .dropout ,
204+ )(
205+ x , deterministic
206+ )
207+ else :
208+ # Input Encoder
209+ for lyr in range (self .depth ):
210+ block_cur = Encoder1DBlock (
211+ block_id = lyr ,
212+ name = f"encoderblock_{ lyr } " ,
213+ dtype_mm = self .dtype_mm ,
214+ mlp_dim = self .mlp_dim ,
215+ num_heads = self .num_heads ,
216+ dropout = self .dropout ,
217+ )
218+ x = block_cur (x , deterministic )
219+ x : jax .Array = nn .LayerNorm (name = "encoder_norm" )(x )
220+ return x
221+
222+
223+ class Einsum (nn .Module ):
224+ """Einsum is a convenience module for parameterized tensor multiplication."""
225+
226+ shape : tuple [int , ...]
227+ weight_name : str = "w"
228+ initializer : nn .initializers .Initializer = nn .initializers .normal ()
229+ dtype : jnp .dtype | None = None
230+
231+ @nn .compact
232+ def __call__ (self , eqn : str , x : jax .Array ) -> jax .Array :
233+ w = self .param (
234+ self .weight_name ,
235+ self .initializer ,
236+ self .shape ,
237+ self .dtype if self .dtype is not None else None ,
238+ )
239+ return jnp .einsum (eqn , x , w )
240+
241+
242+ class VisionEmbedder (nn .Module ):
243+ """Projects image embeddings to the embedding space of the text encoder."""
244+
245+ embed_dim : int
246+ vision_proj_dim : int | None = None
247+
248+ def setup (self ):
249+ if self .vision_proj_dim :
250+ self .mm_soft_embedding_norm = RMSNorm ()
251+ self .mm_input_projection = Einsum ((self .vision_proj_dim , self .embed_dim ))
252+
253+ def encode_vision (self , x : jax .Array ) -> jax .Array :
254+ x = self .mm_soft_embedding_norm (x )
255+ x = self .mm_input_projection ("...tm,md->...td" , x )
256+ return x
257+
258+ def __call__ (self , x : jax .Array ) -> jax .Array :
259+ return self .encode_vision (x )
260+
261+
262+ class VisionExit (nn .Module ):
263+ """The vision exit layer.
264+
265+ Possibly downsample the soft tokens to a required output length.
266+
267+ Attributes:
268+ output_length: The embed will be spatially avg-pooled to this output length.
269+ """
270+
271+ output_length : int = 256
272+
273+ def __call__ (self , x ):
274+ cur_length = x .shape [1 ]
275+ if cur_length == self .output_length :
276+ return x
277+ cur_width = int (cur_length ** 0.5 )
278+ assert cur_width ** 2 == cur_length
279+ output_width = int (self .output_length ** 0.5 )
280+ assert output_width ** 2 == self .output_length , f"Cannot pool { x .shape = } to { self .output_length } =!"
281+ batch_size = x .shape [0 ]
282+ embed_dim = x .shape [- 1 ]
283+ x = jnp .reshape (x , (batch_size , cur_width , cur_width , embed_dim ))
284+ assert not cur_width % output_width , f"{ cur_width = } { output_width = } "
285+ window = cur_width // output_width
286+ window_shape = (window , window )
287+ x = nn .avg_pool (x , window_shape = window_shape , strides = window_shape )
288+ batch_size , height , width , embed_dim = x .shape
289+ return jnp .reshape (x , (batch_size , height * width , embed_dim ))
290+
291+
85292class Gemma3VisionEncoderLayer (nn .Module ):
86293 config : Config
294+ patch_size : tuple [int , int ] = (14 , 14 )
295+ width : int = 1152
296+ mlp_dim : int | None = 4304 # Defaults to 4x input dim
297+ depth : int = 27
298+ num_heads : int = 16
299+ posemb : str = "learn" # Can also be "sincos2d"
300+ dropout : float = 0.0
301+ # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
302+
303+ def _get_posemb (
304+ self ,
305+ typ : str ,
306+ * ,
307+ seqshape : tuple [int , int ],
308+ width : int ,
309+ name : str ,
310+ dtype : jnp .dtype = jnp .float32 ,
311+ ):
312+ """Returns the position embedding."""
313+ if typ == "learn" :
314+ shape_product = seqshape [0 ] * seqshape [1 ]
315+ return self .param (
316+ name ,
317+ nn .initializers .normal (stddev = 1 / (width ** 0.5 )),
318+ (1 , shape_product , width ),
319+ dtype ,
320+ )
321+ elif typ == "sincos2d" :
322+ return _posemb_sincos_2d (* seqshape , width = width , dtype = dtype )
323+ else :
324+ raise ValueError (f"Unknown posemb type: { typ } " )
87325
88326 @nn .compact
89- def __call__ (self , inputs , train = False ):
327+ def __call__ (self , inputs , deterministic , train = False ):
90328 """ViT model that transforms image inputs to image embeddings.
91329 Args:
92330 inputs: jnp.array shaped [B, N, H, W, C], e.g. [4, 1, 896, 896, 3]
93331 Returns:
94332 jnp.array for image embeddings, shaped [B, N, P, D], e.g. [4, 1, 256, 2560]
95333 """
334+ cfg = self .config
96335 b , n , h , w , c = inputs .shape
97336 x = jnp .reshape (inputs , [b * n , h , w , c ])
337+ # Gemma3 uses conv2d with stride 14 and kernel size 14 to extract patches.
98338 x = nn .Conv (features = 1152 , kernel_size = (14 , 14 ), strides = 14 , padding = "VALID" , name = "embedding" )(x )
99- jax .debug .print ("x after: {}" , x .mean ())
100- n , h , w , c = x .shape
101- x = jnp .reshape (x , [n , h * w , c ])
102- # TODO(hengtaoguo): finish the ViT with posemb, dropout and transformation layers.
103- # Currently it is only a placeholder with one Conv layer.
104- # Placeholder x shape (B, 4096, 1152).
339+ bn , h , w , c = x .shape
340+ x = jnp .reshape (x , [bn , h * w , c ])
341+
342+ # Add posemb before adding extra token.
343+ x = x + self ._get_posemb (
344+ self .posemb ,
345+ seqshape = (h , w ),
346+ width = c ,
347+ name = "pos_embedding" ,
348+ dtype = x .dtype ,
349+ )
350+
351+ x = nn .Dropout (rate = self .dropout )(x , not train )
352+
353+ # Transformer encoder to extract image features.
354+ x = Encoder (
355+ depth = self .depth ,
356+ mlp_dim = self .mlp_dim ,
357+ num_heads = self .num_heads ,
358+ dropout = self .dropout ,
359+ scan = cfg .scan_layers ,
360+ remat_policy = cfg .remat_policy_for_vit ,
361+ dtype_mm = cfg .dtype_mm ,
362+ name = "Transformer" ,
363+ )(x , deterministic = deterministic )
364+
365+ # Gemma3 use a vision exit layer to downsample the soft tokens to a required output length.
366+ x = VisionExit (output_length = 256 )(x )
367+ bn , l , c = x .shape
368+ x = jnp .reshape (x , [b , n , l , c ])
369+
370+ # VisionEmbedder is a projection layer that projects the image embeddings to align with text embeddings emb_dim.
371+ x = VisionEmbedder (embed_dim = cfg .emb_dim , vision_proj_dim = self .width )(x )
372+ if cfg .freeze_vision_encoder_params :
373+ x = jax .lax .stop_gradient (x )
105374 return x
106375
107376
0 commit comments