Skip to content

Commit 33164c6

Browse files
authored
Change bias initialization from 'embed' to 'heads'
Fix the bias sharding axis; it should be the output axis instead of the input one.
1 parent 6de9d57 commit 33164c6

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def __init__(
979979
precision=precision,
980980
bias_init=nnx.with_partitioning(
981981
nnx.initializers.zeros,
982-
("embed",),
982+
("heads",),
983983
),
984984
)
985985

@@ -993,7 +993,7 @@ def __init__(
993993
precision=precision,
994994
bias_init=nnx.with_partitioning(
995995
nnx.initializers.zeros,
996-
("embed",),
996+
("heads",),
997997
),
998998
)
999999

@@ -1007,7 +1007,7 @@ def __init__(
10071007
precision=precision,
10081008
bias_init=nnx.with_partitioning(
10091009
nnx.initializers.zeros,
1010-
("embed",),
1010+
("heads",),
10111011
),
10121012
)
10131013

@@ -1021,7 +1021,7 @@ def __init__(
10211021
precision=precision,
10221022
bias_init=nnx.with_partitioning(
10231023
nnx.initializers.zeros,
1024-
("heads",),
1024+
("embed",),
10251025
),
10261026
)
10271027

0 commit comments

Comments (0)