|
10 | 10 |
|
11 | 11 | import mlx.core as mx |
12 | 12 | import mlx.nn as nn |
| 13 | +from mlx.nn.layers.distributed import shard_linear |
13 | 14 | from mlx.utils import tree_unflatten |
14 | 15 | from sentencepiece import SentencePieceProcessor |
15 | 16 |
|
@@ -56,18 +57,10 @@ def shard(self, group: mx.distributed.Group): |
56 | 57 | self.n_kv_heads = self.n_kv_heads // group.size() |
57 | 58 | self.repeats = self.n_heads // self.n_kv_heads |
58 | 59 |
|
59 | | - self.wq = nn.layers.distributed.shard_linear( |
60 | | - self.wq, "all-to-sharded", group=group |
61 | | - ) |
62 | | - self.wk = nn.layers.distributed.shard_linear( |
63 | | - self.wk, "all-to-sharded", group=group |
64 | | - ) |
65 | | - self.wv = nn.layers.distributed.shard_linear( |
66 | | - self.wv, "all-to-sharded", group=group |
67 | | - ) |
68 | | - self.wo = nn.layers.distributed.shard_linear( |
69 | | - self.wo, "sharded-to-all", group=group |
70 | | - ) |
| 60 | + self.wq = shard_linear(self.wq, "all-to-sharded", group=group) |
| 61 | + self.wk = shard_linear(self.wk, "all-to-sharded", group=group) |
| 62 | + self.wv = shard_linear(self.wv, "all-to-sharded", group=group) |
| 63 | + self.wo = shard_linear(self.wo, "sharded-to-all", group=group) |
71 | 64 |
|
72 | 65 | def __call__( |
73 | 66 | self, |
@@ -117,15 +110,9 @@ def __init__(self, args: ModelArgs): |
117 | 110 | self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) |
118 | 111 |
|
119 | 112 | def shard(self, group: mx.distributed.Group): |
120 | | - self.w1 = nn.layers.distributed.shard_linear( |
121 | | - self.w1, "all-to-sharded", group=group |
122 | | - ) |
123 | | - self.w2 = nn.layers.distributed.shard_linear( |
124 | | - self.w2, "sharded-to-all", group=group |
125 | | - ) |
126 | | - self.w3 = nn.layers.distributed.shard_linear( |
127 | | - self.w3, "all-to-sharded", group=group |
128 | | - ) |
| 113 | + self.w1 = shard_linear(self.w1, "all-to-sharded", group=group) |
| 114 | + self.w2 = shard_linear(self.w2, "sharded-to-all", group=group) |
| 115 | + self.w3 = shard_linear(self.w3, "all-to-sharded", group=group) |
129 | 116 |
|
130 | 117 | def __call__(self, x) -> mx.array: |
131 | 118 | return self.w2(nn.silu(self.w1(x)) * self.w3(x)) |
|
0 commit comments