Skip to content

Commit 704bab7

Browse files
committed
import shard_linear
1 parent fb16039 commit 704bab7

1 file changed

Lines changed: 8 additions & 21 deletions

File tree

llms/llama/llama.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
10 10

11 11
import mlx.core as mx
12 12
import mlx.nn as nn
13+
from mlx.nn.layers.distributed import shard_linear
13 14
from mlx.utils import tree_unflatten
14 15
from sentencepiece import SentencePieceProcessor
15 16

@@ -56,18 +57,10 @@ def shard(self, group: mx.distributed.Group):
56 57
self.n_kv_heads = self.n_kv_heads // group.size()
57 58
self.repeats = self.n_heads // self.n_kv_heads
58 59

59-
self.wq = nn.layers.distributed.shard_linear(
60-
self.wq, "all-to-sharded", group=group
61-
)
62-
self.wk = nn.layers.distributed.shard_linear(
63-
self.wk, "all-to-sharded", group=group
64-
)
65-
self.wv = nn.layers.distributed.shard_linear(
66-
self.wv, "all-to-sharded", group=group
67-
)
68-
self.wo = nn.layers.distributed.shard_linear(
69-
self.wo, "sharded-to-all", group=group
70-
)
60+
self.wq = shard_linear(self.wq, "all-to-sharded", group=group)
61+
self.wk = shard_linear(self.wk, "all-to-sharded", group=group)
62+
self.wv = shard_linear(self.wv, "all-to-sharded", group=group)
63+
self.wo = shard_linear(self.wo, "sharded-to-all", group=group)
71 64

72 65
def __call__(
73 66
self,
@@ -117,15 +110,9 @@ def __init__(self, args: ModelArgs):
117 110
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
118 111

119 112
def shard(self, group: mx.distributed.Group):
120-
self.w1 = nn.layers.distributed.shard_linear(
121-
self.w1, "all-to-sharded", group=group
122-
)
123-
self.w2 = nn.layers.distributed.shard_linear(
124-
self.w2, "sharded-to-all", group=group
125-
)
126-
self.w3 = nn.layers.distributed.shard_linear(
127-
self.w3, "all-to-sharded", group=group
128-
)
113+
self.w1 = shard_linear(self.w1, "all-to-sharded", group=group)
114+
self.w2 = shard_linear(self.w2, "sharded-to-all", group=group)
115+
self.w3 = shard_linear(self.w3, "all-to-sharded", group=group)
129 116

130 117
def __call__(self, x) -> mx.array:
131 118
return self.w2(nn.silu(self.w1(x)) * self.w3(x))

0 commit comments

Comments
 (0)