Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/maxtext/utils/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
_ACTIVATION_SHARDINGS_DUMP = []


def clear_input_shardings_dump():
  """Reset the module-level activation-sharding caches to an empty state."""
  # Clear in place (rather than rebinding) so any other module already
  # holding a reference to these containers observes the reset too.
  for cache in (_LOGGED_ACTIVATION_SHARDINGS, _ACTIVATION_SHARDINGS_DUMP):
    cache.clear()


def get_input_data_sharding(config, mesh):
"""Get the input data sharding for the model"""
if config.enable_diloco:
Expand Down
26 changes: 25 additions & 1 deletion tests/unit/sharding_compare_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
import jax.numpy as jnp
from maxtext.configs import pyconfig
from maxtext.utils import maxtext_utils
from maxtext.utils.sharding import clear_input_shardings_dump
# import optax

from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.optimizers import optimizers
from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config
from tests.utils.sharding_dump import TEST_CASES, load_json, named_shardings_to_json, partition_specs_to_json
from tests.utils.sharding_dump import TEST_CASES, load_json, input_sharding_to_json, named_shardings_to_json, partition_specs_to_json
from tests.utils.test_helpers import get_test_config_path
import pytest

Expand Down Expand Up @@ -124,25 +125,34 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
f"compile_topology={topology}",
f"compile_topology_num_slices={num_slice}",
f"model_name={model_name}",
"log_config=false",
"debug_sharding=true", # for input sharding dump
]

root_dir = "tests/utils/sharding_info"
base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")

named_json_path = os.path.join(base_path, "named_shardings.json")
logical_json_path = os.path.join(base_path, "logical_shardings.json")
input_json_path = os.path.join(base_path, "input_shardings.json")

if not os.path.exists(named_json_path):
pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")
return
if not os.path.exists(logical_json_path):
pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")
return
if not os.path.exists(input_json_path):
pytest.skip(f"Missing input_shardings.json for {model_name} {topology} slice {num_slice}")
return

config = pyconfig.initialize(params)
validate_config(config)

clear_input_shardings_dump()
topology_mesh = get_topology_mesh(config)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
optimizers.get_optimizer(config, learning_rate_schedule)
shaped_train_args, _, state_mesh_shardings, logical_shardings, _ = get_shaped_inputs(topology_mesh, config)

error_messages = []
Expand Down Expand Up @@ -173,6 +183,20 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)")
error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")

# 3. Compare Input Shardings
actual_input = input_sharding_to_json()
expected_input = load_json(input_json_path)
# calculate checksum
actual_input_sum = compute_checksum(actual_input)
expected_input_sum = compute_checksum(expected_input)

input_match = actual_input_sum == expected_input_sum

if not input_match:
print(f"\n[FAIL] Input Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
# compare_sharding_jsons(expected_input, "Expected (Input)", actual_input, "Actual (Input)")
error_messages.append(f"Input sharding mismatch for {model_name} on {topology} slice {num_slice}")

assert not error_messages, "\n".join(error_messages)


Expand Down
11 changes: 2 additions & 9 deletions tests/utils/sharding_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,17 +388,10 @@ def partition_specs_to_json(logical_tree, shape_tree) -> dict[str, Any]:
def input_sharding_to_json() -> dict[str, Any]:
  """Package the accumulated activation-sharding dump as a JSON-ready dict."""
  print(f"Got {len(_ACTIVATION_SHARDINGS_DUMP)} Input entries.")
  # Single well-known key so the consumer side can locate the dump entries.
  return {"Activation Sharding Dump": _ACTIVATION_SHARDINGS_DUMP}


def save_activation_shading_dict(output_path: str | Path, sharding_dict: dict) -> None:
"""Save the activation sharding dict directly to a JSON file."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(sharding_dict, f, indent=2)


def save_json(output_path: str | Path, sharding_dict: dict) -> None:
"""Save dict to a JSON file."""
output_path = Path(output_path)
Expand All @@ -408,7 +401,7 @@ def save_json(output_path: str | Path, sharding_dict: dict) -> None:


def load_json(json_path: str | Path) -> dict:
"""Loads the named_shardings.json file into a plain Python dict."""
"""Loads json file into a plain Python dict."""
json_path = Path(json_path)
with open(json_path, "r", encoding="utf-8") as f:
return json.load(f)
Expand Down
Loading