Skip to content

Commit e420909

Browse files
Fix sft_llama3_demo_tpu.ipynb
1 parent 9d05b96 commit e420909

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
12 12
# See the License for the specific language governing permissions and
13 13
# limitations under the License.
14 14

15-
"""Checkpoint conversion utility functions."""
15+
"""Checkpoint conversion utility functions"""
16 16

17 17
import contextlib
18 18
import gc
@@ -892,7 +892,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
892 892
path_keys.append(val_str)
893 893

894 894
# Skip NNX RNG state variables (not model weights)
895-
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
895+
if "to_nnx__rngs" in path_keys or any(k == "rngs" or k.endswith("_rngs") for k in path_keys):
896 896
continue
897 897
# Skip if this is the "value" key itself - we want the parent path
898 898
if path_keys[-1] == "value":

src/maxtext/trainers/diloco/diloco.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
28 28
import drjax
29 29
from flax import nnx
30 30
from flax import struct
31-
from flax.training import train_state
32 31
import jax
33 32
import jax.numpy as jnp
34 33
from jaxtyping import Array, Int32, Key, PyTree, UInt32

src/maxtext/utils/maxtext_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,10 +1668,15 @@ def _make_named_sharding(v):
1668 1668
rules = composite_rules(context_rules, local_rules)
1669 1669
raw_sharding = from_sharding_rules(out_sharding, rules)
1670 1670
mesh_axis_names = mesh.axis_names if mesh is not None else ()
1671-
sanitized_sharding = [
1672-
x if (x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple)) else None
1673-
for x in raw_sharding
1674-
]
1671+
1672+
def _sanitize(x):
1673+
if isinstance(x, list):
1674+
x = tuple(x)
1675+
if x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple):
1676+
return x
1677+
return None
1678+
1679+
sanitized_sharding = [_sanitize(x) for x in raw_sharding]
1675 1680
pspec = PartitionSpec(*sanitized_sharding)
1676 1681
else:
1677 1682
pspec = PartitionSpec(*out_sharding)

0 commit comments

Comments
 (0)