Skip to content

Commit fa0d008

Browse files
Merge pull request #3114 from CIeNET-International:feat/Migrate-Decoder-And-Tests-to-NNX
PiperOrigin-RevId: 911461579
2 parents d1e82b2 + e54dd73 commit fa0d008

6 files changed

Lines changed: 604 additions & 219 deletions

File tree

src/maxtext/layers/initializers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ def variable_to_logically_partitioned(variable: nnx.Variable):
9595
out_sharding = metadata["sharding"]
9696

9797
if out_sharding is not None:
98+
if nnx.PARTITION_NAME in metadata:
99+
partition_name = metadata[nnx.PARTITION_NAME]
100+
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
101+
102+
sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
103+
if partition_name not in sharding_list:
104+
sharding_list.insert(scan_axis, partition_name)
105+
106+
out_sharding = tuple(sharding_list)
107+
98108
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
99109
val,
100110
out_sharding, # type: ignore[arg-type]

0 commit comments

Comments
 (0)