Skip to content

Commit 0c1e488

Browse files
Fix unspecified mesh-axis integer validation
1 parent 2e6cd11 commit 0c1e488

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/maxtext/utils/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
455455
determined_val = target_product / np.prod(parallelism_vals) * -1
456456

457457
assert (
458-
determined_val >= 1 and determined_val.is_integer
458+
determined_val >= 1 and determined_val.is_integer()
459459
), f"Unspecified value unable to be determined with the given\
460460
{parallelism_type} parallelism values"
461461

tests/unit/max_utils_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ def test_invalid_strategy(self):
137137
max_utils.is_valid_custom_mesh([1, 1, 1, 1, 1, 16, 16, 1], "invalid_strategy")
138138

139139

140+
class FillUnspecifiedMeshAxesTest(unittest.TestCase):
141+
"""Tests for fill_unspecified_mesh_axes."""
142+
143+
def test_rejects_non_integer_unspecified_value(self):
144+
with self.assertRaises(AssertionError):
145+
max_utils.fill_unspecified_mesh_axes([2, -1, 3], 10, "ICI")
146+
147+
def test_fills_integer_unspecified_value(self):
148+
self.assertEqual(
149+
max_utils.fill_unspecified_mesh_axes([2, -1, 5], 20, "ICI"),
150+
[2, 2, 5],
151+
)
152+
153+
140154
class UnscanTest(unittest.TestCase):
141155
"""Test unscanning utility."""
142156

0 commit comments

Comments
 (0)