Skip to content

Commit e54dd73

Browse files
Implement and update the following models in NNX decoder: DeepSeek/Gemma3/Llama4
1 parent 7d6e1ca commit e54dd73

6 files changed

Lines changed: 604 additions & 219 deletions

File tree

src/maxtext/layers/initializers.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,16 @@ def variable_to_logically_partitioned(variable: nnx.Variable):
9595
out_sharding = metadata["sharding"]
9696

9797
if out_sharding is not None:
98+
if nnx.PARTITION_NAME in metadata:
99+
partition_name = metadata[nnx.PARTITION_NAME]
100+
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
101+
102+
sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
103+
if partition_name not in sharding_list:
104+
sharding_list.insert(scan_axis, partition_name)
105+
106+
out_sharding = tuple(sharding_list)
107+
98108
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
99109
val,
100110
out_sharding, # type: ignore[arg-type]

0 commit comments

Comments
 (0)