We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6de9d57 commit 33164c6Copy full SHA for 33164c6
1 file changed
src/maxdiffusion/models/attention_flax.py
@@ -979,7 +979,7 @@ def __init__(
979
precision=precision,
980
bias_init=nnx.with_partitioning(
981
nnx.initializers.zeros,
982
- ("embed",),
+ ("heads",),
983
),
984
)
985
@@ -993,7 +993,7 @@ def __init__(
993
994
995
996
997
998
999
@@ -1007,7 +1007,7 @@ def __init__(
1007
1008
1009
1010
1011
1012
1013
@@ -1021,7 +1021,7 @@ def __init__(
1021
1022
1023
1024
- ("heads",),
+ ("embed",),
1025
1026
1027
0 commit comments