We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3fe7794 commit d2bef3cCopy full SHA for d2bef3c
1 file changed
python/mlx/nn/layers/distributed.py
@@ -86,6 +86,8 @@ def _all_to_sharded(segments):
86
representation becomes a sharded representation."""
87
88
def _shard_fn(path, weight):
89
+ if path.endswith("bias"):
90
+ return -1, segments
91
return max(weight.ndim - 2, 0), segments
92
93
return _shard_fn
0 commit comments