Skip to content

Commit 0a1f4e9

Browse files
xibinliuecnal-cienet
authored and committed
NNX: update train/eval step sharding signatures to omit rng for pure_nnx
- get_functional_train_with_signature: use (state, batch) shardings when pure_nnx=True
- get_functional_eval_with_signature: use (state, batch) shardings when pure_nnx=True
1 parent d901179 commit 0a1f4e9

3 files changed

Lines changed: 10 additions & 2 deletions

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@ def get_functional_train_with_signature(
9393
"""Get the shardings (both state and data) for `train_step`."""
9494
functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
9595
functional_train.__name__ = "train_step"
96-
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
96+
if config.pure_nnx:
97+
in_shardings = (state_mesh_shardings, data_sharding) # State, batch
98+
else:
99+
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
97100
out_shardings = (state_mesh_shardings, None) # State, metrics
98101
static_argnums = () # We partial out the static argnums of model and config
99102
donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory.
@@ -104,7 +107,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar
104107
"""Get the shardings (both state and data) for `eval_step`."""
105108
functional_eval = functools.partial(eval_step, model, config)
106109
functional_eval.__name__ = "eval_step"
107-
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
110+
if config.pure_nnx:
111+
in_shardings = (state_mesh_shardings, data_sharding) # State, batch
112+
else:
113+
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
108114
out_shardings = None # metrics
109115
static_argnums = () # We partial out the static argnums of model, config
110116
donate_argnums = () # state will be kept instead of being donated in eval_step

tests/integration/aot_identical_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
179179
"enable_checkpointing=False",
180180
"dump_jaxpr=True",
181181
"dump_jaxpr_delete_local_after=False",
182+
"skip_first_n_steps_for_profiler=0",
182183
]
183184
if extra_args:
184185
shared_args.extend(extra_args)

tests/integration/xaot_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def run_compile_then_load(self, test_name, *extra_args):
8080
"learning_rate=1e-3",
8181
"dataset_type=synthetic",
8282
"enable_checkpointing=False",
83+
"profiler=''",
8384
]
8485

8586
if extra_args:

0 commit comments

Comments (0)