|
58 | 58 | flags.DEFINE_string("model_name", None, "Specific model name to dump.") |
59 | 59 | flags.DEFINE_string("topology", None, "Specific topology to dump.") |
60 | 60 | flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.") |
| 61 | +flags.DEFINE_bool("pure_nnx", False, "Use pure NNX model.") |
61 | 62 |
|
62 | 63 |
|
63 | | -def run_single_dump(model_name: str, topology: str, num_slice: str) -> None: |
| 64 | +def run_single_dump(model_name: str, topology: str, num_slice: str, pure_nnx: bool = False) -> None: |
64 | 65 | """Generate sharding json file for one specific model, topology and slice.""" |
65 | | - subprocess.run( |
66 | | - [ |
67 | | - "python3", |
68 | | - "-m", |
69 | | - "tests.utils.sharding_dump", |
70 | | - get_test_config_path(), |
71 | | - f"compile_topology={topology}", |
72 | | - f"compile_topology_num_slices={num_slice}", |
73 | | - f"model_name={model_name}", |
74 | | - "weight_dtype=float32", |
75 | | - "log_config=false", |
76 | | - "debug_sharding=true", |
77 | | - ], |
78 | | - check=True, |
79 | | - ) |
| 66 | + cmd = [ |
| 67 | + "python3", |
| 68 | + "-m", |
| 69 | + "tests.utils.sharding_dump", |
| 70 | + get_test_config_path(), |
| 71 | + f"compile_topology={topology}", |
| 72 | + f"compile_topology_num_slices={num_slice}", |
| 73 | + f"model_name={model_name}", |
| 74 | + "weight_dtype=float32", |
| 75 | + "log_config=false", |
| 76 | + "debug_sharding=true", |
| 77 | + ] |
| 78 | + if pure_nnx: |
| 79 | + cmd.append("pure_nnx=true") |
| 80 | + subprocess.run(cmd, check=True) |
80 | 81 |
|
81 | 82 |
|
82 | 83 | def main(argv: Sequence[str]) -> None: |
@@ -106,7 +107,7 @@ def main(argv: Sequence[str]) -> None: |
106 | 107 | print(" -> Sharding files already exist. Regenerating to overwrite.") |
107 | 108 |
|
108 | 109 | try: |
109 | | - run_single_dump(model_name, topology, str(num_slice)) |
| 110 | + run_single_dump(model_name, topology, str(num_slice), pure_nnx=FLAGS.pure_nnx) |
110 | 111 | except subprocess.CalledProcessError: |
111 | 112 | print(f"!!! FAILED: {model_name} {topology} {num_slice}") |
112 | 113 |
|
|
0 commit comments