diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py
index b890e2f8b4..7de58ed50e 100644
--- a/src/maxtext/utils/sharding.py
+++ b/src/maxtext/utils/sharding.py
@@ -36,6 +36,12 @@
 _ACTIVATION_SHARDINGS_DUMP = []
 
 
+def clear_input_shardings_dump():
+  """Clear the input shardings dump"""
+  _LOGGED_ACTIVATION_SHARDINGS.clear()
+  _ACTIVATION_SHARDINGS_DUMP.clear()
+
+
 def get_input_data_sharding(config, mesh):
   """Get the input data sharding for the model"""
   if config.enable_diloco:
diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py
index 8158ae7f47..2cd696f241 100644
--- a/tests/unit/sharding_compare_test.py
+++ b/tests/unit/sharding_compare_test.py
@@ -21,13 +21,14 @@
 import jax.numpy as jnp
 
 from maxtext.configs import pyconfig
 from maxtext.utils import maxtext_utils
+from maxtext.utils.sharding import clear_input_shardings_dump
 # import optax
 from maxtext.layers import quantizations
 from maxtext.models import models
 from maxtext.optimizers import optimizers
 from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config
-from tests.utils.sharding_dump import TEST_CASES, load_json, named_shardings_to_json, partition_specs_to_json
+from tests.utils.sharding_dump import TEST_CASES, load_json, input_sharding_to_json, named_shardings_to_json, partition_specs_to_json
 from tests.utils.test_helpers import get_test_config_path
 
 import pytest
@@ -124,6 +125,8 @@
       f"compile_topology={topology}",
       f"compile_topology_num_slices={num_slice}",
       f"model_name={model_name}",
+      "log_config=false",
+      "debug_sharding=true",  # for input sharding dump
   ]
 
   root_dir = "tests/utils/sharding_info"
@@ -131,6 +134,7 @@
 
   named_json_path = os.path.join(base_path, "named_shardings.json")
   logical_json_path = os.path.join(base_path, "logical_shardings.json")
+  input_json_path = os.path.join(base_path, "input_shardings.json")
 
   if not os.path.exists(named_json_path):
     pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")
@@ -138,11 +142,17 @@
   if not os.path.exists(logical_json_path):
     pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")
     return
+  if not os.path.exists(input_json_path):
+    pytest.skip(f"Missing input_shardings.json for {model_name} {topology} slice {num_slice}")
+    return
 
   config = pyconfig.initialize(params)
   validate_config(config)
+  clear_input_shardings_dump()
 
   topology_mesh = get_topology_mesh(config)
+  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
+  optimizers.get_optimizer(config, learning_rate_schedule)
   shaped_train_args, _, state_mesh_shardings, logical_shardings, _ = get_shaped_inputs(topology_mesh, config)
 
   error_messages = []
@@ -173,6 +183,20 @@
     compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)")
     error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")
 
+  # 3. Compare Input Shardings
+  actual_input = input_sharding_to_json()
+  expected_input = load_json(input_json_path)
+  # calculate checksum
+  actual_input_sum = compute_checksum(actual_input)
+  expected_input_sum = compute_checksum(expected_input)
+
+  input_match = actual_input_sum == expected_input_sum
+
+  if not input_match:
+    print(f"\n[FAIL] Input Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
+    # compare_sharding_jsons(expected_input, "Expected (Input)", actual_input, "Actual (Input)")
+    error_messages.append(f"Input sharding mismatch for {model_name} on {topology} slice {num_slice}")
+
   assert not error_messages, "\n".join(error_messages)
 
 
diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py
index 4eebe55322..0b0cfd4628 100644
--- a/tests/utils/sharding_dump.py
+++ b/tests/utils/sharding_dump.py
@@ -388,17 +388,10 @@ def partition_specs_to_json(logical_tree, shape_tree) -> dict[str, Any]:
 def input_sharding_to_json() -> dict[str, Any]:
   input_sharding = {}
   input_sharding["Activation Sharding Dump"] = _ACTIVATION_SHARDINGS_DUMP
+  print(f"Got {len(_ACTIVATION_SHARDINGS_DUMP)} Input entries.")
   return input_sharding
 
 
-def save_activation_shading_dict(output_path: str | Path, sharding_dict: dict) -> None:
-  """Save the activation sharding dict directly to a JSON file."""
-  output_path = Path(output_path)
-  output_path.parent.mkdir(parents=True, exist_ok=True)
-  with open(output_path, "w", encoding="utf-8") as f:
-    json.dump(sharding_dict, f, indent=2)
-
-
 def save_json(output_path: str | Path, sharding_dict: dict) -> None:
   """Save dict to a JSON file."""
   output_path = Path(output_path)
@@ -408,7 +401,7 @@
 
 
 def load_json(json_path: str | Path) -> dict:
-  """Loads the named_shardings.json file into a plain Python dict."""
+  """Loads json file into a plain Python dict."""
   json_path = Path(json_path)
   with open(json_path, "r", encoding="utf-8") as f:
     return json.load(f)