@@ -241,6 +241,82 @@ def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
241241 return hidden_states
242242
243243
class NNXSequentialPipelineStage(nnx.Module):
  """Pipeline stage that applies its decoder layers one after another (no scan).

  The stage owns `num_layers` independent instances of `layer_cls`, stored as the
  attributes `layers_0`, `layers_1`, ... and invoked in index order.
  """

  def __init__(
      self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs
  ):
    """Builds `num_layers` layers, each constructed with the same arguments.

    Args:
      layer_cls: Callable used to construct every layer; it is invoked with the
        keyword arguments `config`, `mesh`, `quant`, `model_mode` and `rngs`.
      num_layers: Number of layers held by this stage.
      config: Configuration forwarded to every layer; `config.scan_layers` also
        selects the return format of `__call__`.
      mesh: Mesh forwarded to every layer.
      quant: Quantization argument forwarded to every layer.
      model_mode: Model mode forwarded to every layer.
      rngs: RNG container; the same object is handed to every layer constructor.
    """
    self.config = config
    self.scan_layers = config.scan_layers
    self.num_layers = num_layers

    layer_kwargs = {"config": config, "mesh": mesh, "quant": quant, "model_mode": model_mode, "rngs": rngs}
    for layer_idx in range(num_layers):
      # Each layer is registered under its own attribute name, `layers_<index>`.
      setattr(self, f"layers_{layer_idx}", layer_cls(**layer_kwargs))

  def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs):
    """Runs the layers in order, feeding each layer's output into the next.

    Args:
      inputs: Input passed to the first layer.
      decoder_segment_ids: Forwarded unchanged to every layer.
      decoder_positions: Forwarded unchanged to every layer.
      deterministic: Forwarded unchanged to every layer.
      model_mode: Forwarded unchanged to every layer.
      **kwargs: Extra keyword arguments forwarded unchanged to every layer.

    Returns:
      `(output, None)` when `config.scan_layers` is truthy, otherwise `output`,
      where `output` is the result of the last layer (or `inputs` if the stage
      has no layers).
    """
    hidden = inputs
    for layer_idx in range(self.num_layers):
      stage_layer = getattr(self, f"layers_{layer_idx}")
      result = stage_layer(hidden, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs)
      # A layer may return a tuple; only its first element is carried forward.
      hidden = result[0] if isinstance(result, tuple) else result

    return (hidden, None) if self.scan_layers else hidden
266+
class NNXScannedPipelineStage(nnx.Module):
  """Scanned block of decoder layers formatted for a single pipeline stage.

  The layers are created in one `nnx.vmap` call, so their variables are stacked
  along an extra axis, and they are applied with `jax.lax.scan` in `__call__`.
  """

  def __init__(
      self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs
  ):
    """Creates the stacked layers.

    Args:
      layer_cls: Callable used to construct a layer; it is invoked with the
        keyword arguments `config`, `mesh`, `quant`, `model_mode` and `rngs`.
      num_layers: Number of layers; used as the split count when forking `rngs`.
      config: Configuration forwarded to the layers. `config.param_scan_axis`
        is the axis along which `nnx.Param` variables are stacked, and
        `config.scan_layers` selects the return format of `__call__`.
      mesh: Mesh forwarded to the layers.
      quant: Quantization argument forwarded to the layers.
      model_mode: Model mode forwarded to the layers.
      rngs: RNG container, forked into `num_layers` streams when possible.
    """
    self.config = config

    def create_layer_fn(rng):
      return layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rng)

    try:
      forked_rngs = rngs.fork(split=num_layers)
    except Exception:  # pylint: disable=broad-exception-caught
      # Best-effort fallback kept from the original code: use the rngs as given
      # when forking fails. Only `Exception` is caught so that KeyboardInterrupt
      # and SystemExit still propagate.
      # NOTE(review): presumably this covers rngs objects without `fork(split=...)`
      # support - confirm which exception types are actually expected here.
      forked_rngs = rngs

    # Params are stacked along `param_scan_axis`; every other variable on axis 0.
    out_axes = nnx.StateAxes({nnx.Param: config.param_scan_axis, ...: 0})
    self.scanned_layers = nnx.vmap(
        create_layer_fn,
        in_axes=0,
        out_axes=out_axes,
        axis_name="layers_per_stage",
        transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"},
    )(forked_rngs)

  def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs):
    """Applies the stacked layers sequentially with `jax.lax.scan`.

    Args:
      inputs: Initial carry passed to the first layer.
      decoder_segment_ids: Forwarded unchanged to every layer.
      decoder_positions: Forwarded unchanged to every layer.
      deterministic: Forwarded unchanged to every layer.
      model_mode: Forwarded unchanged to every layer.
      **kwargs: Extra keyword arguments forwarded unchanged to every layer.

    Returns:
      `(final_carry, None)` when `config.scan_layers` is truthy, otherwise
      `final_carry`, where `final_carry` is the carry after the last layer.
    """
    graphdef, params, state = nnx.split(self.scanned_layers, nnx.Param, ...)

    scan_axis = self.config.param_scan_axis

    def to_leading_axis(x):
      return jnp.moveaxis(x, scan_axis, 0)

    def from_leading_axis(x):
      return jnp.moveaxis(x, 0, scan_axis)

    if scan_axis != 0:
      # `jax.lax.scan` iterates over the leading axis, so move the stacking axis
      # of the params to the front; the other state is already stacked on axis 0.
      params = jax.tree.map(to_leading_axis, params)

    def layer_fn(carry, scanned_vars):
      current_params, current_state = scanned_vars
      # Rebuild one layer from its slice of the stacked variables.
      layer = nnx.merge(graphdef, current_params, current_state)
      layer_out = layer(carry, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs)
      # A layer may return a tuple; only its first element is carried forward.
      new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
      # Return the layer's state so that scan re-stacks it along axis 0.
      return new_carry, nnx.state(layer)

    final_carry, scanned_state = jax.lax.scan(layer_fn, inputs, (params, state))

    if scan_axis != 0:
      # Scan stacked everything on axis 0; restore the params' configured axis.
      scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
      scanned_params = jax.tree.map(from_leading_axis, scanned_params)
      scanned_state = nnx.State.merge(scanned_params, scanned_other)

    # Write the state produced inside the scan back onto the stored layers.
    nnx.update(self.scanned_layers, scanned_state)

    if self.config.scan_layers:
      return final_carry, None
    return final_carry
319+
244320class NNXDecoder (nnx .Module ):
245321 """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX."""
246322
@@ -992,7 +1068,6 @@ def __call__(
9921068 previous_chunk = None ,
9931069 slot : None | int = None ,
9941070 page_state : None | page_manager .PageState = None ,
995- multimodal_input : None | Any = None ,
9961071 kv_caches : list [jax .Array ] | None = None ,
9971072 attention_metadata = None ,
9981073 deepstack_visual_embeds : None | list [jnp .ndarray ] = None ,
0 commit comments