Skip to content

Commit 3ed2aad

Browse files
committed
NNX: add --pure_nnx flag to run_sharding_dump.py
- Add --pure_nnx CLI flag to run_sharding_dump.py - Propagate pure_nnx=true to the sharding_dump subprocess when flag is set - Refactor run_single_dump() to build the command as a list for conditional flag appending
1 parent d536b12 commit 3ed2aad

1 file changed

Lines changed: 18 additions & 17 deletions

File tree

tests/utils/run_sharding_dump.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,26 @@
5858
flags.DEFINE_string("model_name", None, "Specific model name to dump.")
5959
flags.DEFINE_string("topology", None, "Specific topology to dump.")
6060
flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.")
61+
flags.DEFINE_bool("pure_nnx", False, "Use pure NNX model.")
6162

6263

63-
def run_single_dump(model_name: str, topology: str, num_slice: str) -> None:
64+
def run_single_dump(model_name: str, topology: str, num_slice: str, pure_nnx: bool = False) -> None:
6465
"""Generate sharding json file for one specific model, topology and slice."""
65-
subprocess.run(
66-
[
67-
"python3",
68-
"-m",
69-
"tests.utils.sharding_dump",
70-
get_test_config_path(),
71-
f"compile_topology={topology}",
72-
f"compile_topology_num_slices={num_slice}",
73-
f"model_name={model_name}",
74-
"weight_dtype=float32",
75-
"log_config=false",
76-
"debug_sharding=true",
77-
],
78-
check=True,
79-
)
66+
cmd = [
67+
"python3",
68+
"-m",
69+
"tests.utils.sharding_dump",
70+
get_test_config_path(),
71+
f"compile_topology={topology}",
72+
f"compile_topology_num_slices={num_slice}",
73+
f"model_name={model_name}",
74+
"weight_dtype=float32",
75+
"log_config=false",
76+
"debug_sharding=true",
77+
]
78+
if pure_nnx:
79+
cmd.append("pure_nnx=true")
80+
subprocess.run(cmd, check=True)
8081

8182

8283
def main(argv: Sequence[str]) -> None:
@@ -106,7 +107,7 @@ def main(argv: Sequence[str]) -> None:
106107
print(" -> Sharding files already exist. Regenerating to overwrite.")
107108

108109
try:
109-
run_single_dump(model_name, topology, str(num_slice))
110+
run_single_dump(model_name, topology, str(num_slice), pure_nnx=FLAGS.pure_nnx)
110111
except subprocess.CalledProcessError:
111112
print(f"!!! FAILED: {model_name} {topology} {num_slice}")
112113

0 commit comments

Comments
 (0)