Skip to content

Commit afa96ee

Browse files
hsuan-lun-chiangecnal-cienet
authored andcommitted
Fix sft_llama3_demo_tpu.ipynb
1 parent 3324f60 commit afa96ee

3 files changed

Lines changed: 10 additions & 6 deletions

File tree

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
895895
path_keys.append(val_str)
896896

897897
# Skip NNX RNG state variables (not model weights)
898-
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
898+
if "to_nnx__rngs" in path_keys or any(k == "rngs" or k.endswith("_rngs") for k in path_keys):
899899
continue
900900
# Skip if this is the "value" key itself - we want the parent path
901901
if path_keys[-1] == "value":

src/maxtext/trainers/diloco/diloco.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import drjax
2929
from flax import nnx
3030
from flax import struct
31-
from flax.training import train_state
3231
import jax
3332
import jax.numpy as jnp
3433
from jaxtyping import Array, Int32, Key, PyTree, UInt32

src/maxtext/utils/maxtext_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,10 +1667,15 @@ def _make_named_sharding(v):
16671667
rules = composite_rules(context_rules, local_rules)
16681668
raw_sharding = from_sharding_rules(out_sharding, rules)
16691669
mesh_axis_names = mesh.axis_names if mesh is not None else ()
1670-
sanitized_sharding = [
1671-
x if (x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple)) else None
1672-
for x in raw_sharding
1673-
]
1670+
1671+
def _sanitize(x):
1672+
if isinstance(x, list):
1673+
x = tuple(x)
1674+
if x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple):
1675+
return x
1676+
return None
1677+
1678+
sanitized_sharding = [_sanitize(x) for x in raw_sharding]
16741679
pspec = PartitionSpec(*sanitized_sharding)
16751680
else:
16761681
pspec = PartitionSpec(*out_sharding)

0 commit comments

Comments
 (0)