Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions mlx_lm/models/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mlx.nn.layers.distributed import shard_inplace, shard_linear, sum_gradients

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .pipeline import PipelineMixin
from .switch_layers import SwitchGLU


Expand Down Expand Up @@ -221,9 +222,10 @@ def __call__(
return r


class MiniMaxModel(nn.Module):
class MiniMaxModel(PipelineMixin, nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

self.layers = [
Expand All @@ -241,15 +243,32 @@ def __call__(
h = self.embed_tokens(inputs)

if cache is None:
cache = [None] * len(self.layers)
cache = [None] * len(self.pipeline_layers)

mask = create_attention_mask(h, cache[0])

for layer, c in zip(self.layers, cache):
if self.pipeline_rank < self.pipeline_size - 1:
h = mx.distributed.recv_like(h, self.pipeline_rank + 1)

for layer, c in zip(self.pipeline_layers, cache):
h = layer(h, mask, c)

if self.pipeline_rank != 0:
h = mx.distributed.send(h, (self.pipeline_rank - 1) % self.pipeline_size)

if self.pipeline_size > 1:
h = mx.distributed.all_gather(h)[: h.shape[0]]

return self.norm(h)

def make_cache(self):
from .cache import KVCache

return [
KVCache()
for _ in self.pipeline_layers
]


class Model(nn.Module):
def __init__(self, args: ModelArgs):
Expand Down Expand Up @@ -373,6 +392,12 @@ def shard(self, group: Optional[mx.distributed.Group] = None):
)
layer.block_sparse_moe.sharding_group = group

def make_cache(self):
if hasattr(self.model, "make_cache"):
return self.model.make_cache()
from .cache import KVCache
return [KVCache() for _ in self.model.layers]

    @property
    def layers(self):
        # Convenience accessor used by external utilities (e.g. quantization,
        # sharding helpers) that iterate over the decoder layers.
        # NOTE(review): this returns model.layers — presumably the FULL layer
        # list, not the pipeline rank's slice (pipeline_layers); confirm that
        # callers expect the unsliced list under pipeline parallelism.
        return self.model.layers
Expand Down
39 changes: 34 additions & 5 deletions mlx_lm/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .activations import swiglu
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .pipeline import PipelineMixin
from .switch_layers import SwitchGLU


Expand Down Expand Up @@ -171,7 +172,7 @@ def __call__(
return out


class Qwen3MoeModel(nn.Module):
class Qwen3MoeModel(PipelineMixin, nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
Expand All @@ -197,15 +198,37 @@ def __call__(
h = self.embed_tokens(inputs)

if cache is None:
cache = [None] * len(self.layers)
cache = [None] * len(self.pipeline_layers)

mask = create_attention_mask(h, cache[0])
mask = create_attention_mask(h, cache[0], return_array=True)

for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
# In pipeline parallel, receive hidden state from the next rank
# (which processed earlier layers)
if self.pipeline_rank < self.pipeline_size - 1:
h = mx.distributed.recv_like(h, self.pipeline_rank + 1)

for layer, c in zip(self.pipeline_layers, cache):
h = layer(h, mask, cache=c)

# Send hidden state to the previous rank (which has later layers)
if self.pipeline_rank != 0:
h = mx.distributed.send(h, (self.pipeline_rank - 1) % self.pipeline_size)

# All ranks need the final output for lm_head
if self.pipeline_size > 1:
h = mx.distributed.all_gather(h)[: h.shape[0]]

return self.norm(h)

def make_cache(self):
"""Return KV cache entries only for this rank's layers."""
from .cache import KVCache

return [
KVCache()
for _ in self.pipeline_layers
]


class Model(nn.Module):
def __init__(self, args: ModelArgs):
Expand Down Expand Up @@ -254,6 +277,12 @@ def predicate(path, _):

return predicate

def make_cache(self):
if hasattr(self.model, "make_cache"):
return self.model.make_cache()
from .cache import KVCache
return [KVCache() for _ in self.model.layers]

    @property
    def layers(self):
        # Convenience accessor used by external utilities (e.g. quantization,
        # sharding helpers) that iterate over the decoder layers.
        # NOTE(review): returns model.layers — presumably the full, unsliced
        # layer list rather than this rank's pipeline_layers; verify callers
        # under pipeline parallelism.
        return self.model.layers