1919from collections .abc import Sequence
2020from functools import partial
2121from typing import overload
22-
2322from etils import epath
2423from flax import nnx
2524import flax .linen as nn
2625import jax
2726import jax .numpy as jnp
28- from jax .sharding import AxisType , Mesh
27+ from jax .sharding import Mesh
2928from maxtext .configs import pyconfig
30- from maxtext .common .common_types import MODEL_MODE_TRAIN , ShardMode
29+ from maxtext .common .common_types import MODEL_MODE_TRAIN
3130from maxtext .layers import quantizations
3231from maxtext .models import models
3332from maxtext .utils import max_logging
34- from maxtext .utils import max_utils
35- from maxtext .utils import maxtext_utils
33+ from maxtext .utils import max_utils , maxtext_utils , maxtext_utils_nnx
3634from orbax import checkpoint as ocp
3735
3836try :
@@ -154,6 +152,7 @@ def from_config(
154152 mesh : Mesh | None = None ,
155153 * ,
156154 model_mode : str = MODEL_MODE_TRAIN ,
155+ rngs : None = None ,
157156) -> nn .Module :
158157 ...
159158
@@ -194,15 +193,7 @@ def from_config(
194193 model = from_config(config)
195194 """
196195 if mesh is None :
197- devices_array = maxtext_utils .create_device_mesh (config , devices )
198-
199- if config .shard_mode == ShardMode .EXPLICIT :
200- axis_types = tuple ([AxisType .Explicit ] * len (config .mesh_axes ))
201- else :
202- axis_types = tuple ([AxisType .Auto ] * len (config .mesh_axes ))
203-
204- mesh = Mesh (devices_array , config .mesh_axes , axis_types = axis_types )
205-
196+ mesh = maxtext_utils .get_mesh_from_config (config , devices )
206197 model = create_model (config , mesh , model_mode = model_mode , rngs = rngs )
207198
208199 # Return only the model
@@ -245,9 +236,7 @@ def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key
245236 """
246237
247238 def _create_model (rng_key = None ):
248- if rng_key is None :
249- rng_key = jax .random .PRNGKey (config .init_weights_seed )
250- rngs = nnx .Rngs (params = rng_key , dropout = 1 )
239+ rngs = maxtext_utils_nnx .create_nnx_rngs (config , model_mode = model_mode , rng_key = rng_key )
251240 return from_config (config , mesh = mesh , rngs = rngs , model_mode = model_mode )
252241
253242 _create_model_partial = partial (_create_model , rng_key = rng_key )
@@ -262,14 +251,7 @@ def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAI
262251 """Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
263252
264253 def _create_model (mesh : Mesh | None = None , model_mode : str = MODEL_MODE_TRAIN , rng_key : jax .Array | None = None ):
265- if rng_key is None :
266- rng_key = jax .random .PRNGKey (config .init_weights_seed )
267-
268- if model_mode == MODEL_MODE_TRAIN :
269- rngs = nnx .Rngs (params = rng_key , dropout = 1 )
270- else :
271- rngs = nnx .Rngs (params = rng_key ) # disable dropout RNG for inference
272-
254+ rngs = maxtext_utils_nnx .create_nnx_rngs (config , model_mode = model_mode , rng_key = rng_key )
273255 return from_config (config , devices , mesh , rngs = rngs , model_mode = model_mode )
274256
275257 _create_model_partial = partial (_create_model , mesh = mesh , model_mode = model_mode , rng_key = rng_key )
@@ -282,6 +264,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN,
282264 if mesh is None :
283265 mesh = abstract_model .mesh
284266
267+ # Note for pure_nnx:
268+ # Currently, the returned NNX model contains a Linen decoder wrapped into NNX, so it is not a pure NNX model and
269+ # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the output sharding from the Linen
270+ # LogicallyPartitioned structure.
271+ # In the future, if a pure NNX model is used, pure NNX's eager sharding means there will be no LogicallyPartitioned
272+ # structure in the abstract state, and we can get the sharded state with the following code:
273+ # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh)
274+ # abstract_model = nnx.merge(graphdef, state)
275+ # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh)
276+ # sharded_state = nnx.state(model)
277+
285278 # JIT a function that creates the model state with proper sharding from the start.
286279 # By providing out_shardings, we instruct JAX to produce sharded output directly,
287280 # avoiding a large intermediate allocation on a single device.
0 commit comments