|
21 | 21 | import jax.numpy as jnp |
22 | 22 | from maxtext.configs import pyconfig |
23 | 23 | from maxtext.utils import maxtext_utils |
| 24 | +from maxtext.utils.sharding import clear_input_shardings_dump |
24 | 25 | # import optax |
25 | 26 |
|
26 | 27 | from maxtext.layers import quantizations |
27 | 28 | from maxtext.models import models |
28 | 29 | from maxtext.optimizers import optimizers |
29 | 30 | from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config |
30 | | -from tests.utils.sharding_dump import TEST_CASES, load_json, named_shardings_to_json, partition_specs_to_json |
| 31 | +from tests.utils.sharding_dump import TEST_CASES, load_json, input_sharding_to_json, named_shardings_to_json, partition_specs_to_json |
31 | 32 | from tests.utils.test_helpers import get_test_config_path |
32 | 33 | import pytest |
33 | 34 |
|
@@ -124,25 +125,34 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) |
124 | 125 | f"compile_topology={topology}", |
125 | 126 | f"compile_topology_num_slices={num_slice}", |
126 | 127 | f"model_name={model_name}", |
| 128 | + "log_config=false", |
| 129 | + "debug_sharding=true", # for input sharding dump |
127 | 130 | ] |
128 | 131 |
|
129 | 132 | root_dir = "tests/utils/sharding_info" |
130 | 133 | base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}") |
131 | 134 |
|
132 | 135 | named_json_path = os.path.join(base_path, "named_shardings.json") |
133 | 136 | logical_json_path = os.path.join(base_path, "logical_shardings.json") |
| 137 | + input_json_path = os.path.join(base_path, "input_shardings.json") |
134 | 138 |
|
135 | 139 | if not os.path.exists(named_json_path): |
136 | 140 | pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}") |
137 | 141 | return |
138 | 142 | if not os.path.exists(logical_json_path): |
139 | 143 | pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}") |
140 | 144 | return |
| 145 | + if not os.path.exists(input_json_path): |
| 146 | + pytest.skip(f"Missing input_shardings.json for {model_name} {topology} slice {num_slice}") |
| 147 | + return |
141 | 148 |
|
142 | 149 | config = pyconfig.initialize(params) |
143 | 150 | validate_config(config) |
144 | 151 |
|
| 152 | + clear_input_shardings_dump() |
145 | 153 | topology_mesh = get_topology_mesh(config) |
| 154 | + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) |
| 155 | + optimizers.get_optimizer(config, learning_rate_schedule) |
146 | 156 | shaped_train_args, _, state_mesh_shardings, logical_shardings, _ = get_shaped_inputs(topology_mesh, config) |
147 | 157 |
|
148 | 158 | error_messages = [] |
@@ -173,6 +183,20 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) |
173 | 183 | compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)") |
174 | 184 | error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}") |
175 | 185 |
|
| 186 | + # 3. Compare Input Shardings |
| 187 | + actual_input = input_sharding_to_json() |
| 188 | + expected_input = load_json(input_json_path) |
| 189 | + # calculate checksum |
| 190 | + actual_input_sum = compute_checksum(actual_input) |
| 191 | + expected_input_sum = compute_checksum(expected_input) |
| 192 | + |
| 193 | + input_match = actual_input_sum == expected_input_sum |
| 194 | + |
| 195 | + if not input_match: |
| 196 | + print(f"\n[FAIL] Input Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True) |
| 197 | + # compare_sharding_jsons(expected_input, "Expected (Input)", actual_input, "Actual (Input)") |
| 198 | + error_messages.append(f"Input sharding mismatch for {model_name} on {topology} slice {num_slice}") |
| 199 | + |
176 | 200 | assert not error_messages, "\n".join(error_messages) |
177 | 201 |
|
178 | 202 |
|
|
0 commit comments